Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- multimodal/examples/albef/configs/retrieval.yaml +73 -0
- multimodal/examples/albef/configs/vqa.yaml +78 -0
- multimodal/examples/albef/data/__init__.py +5 -0
- multimodal/examples/albef/data/retrieval_datamodule.py +188 -0
- multimodal/examples/albef/data/retrieval_dataset.py +149 -0
- multimodal/examples/albef/data/transforms.py +141 -0
- multimodal/examples/albef/data/vqa_datamodules.py +206 -0
- multimodal/examples/albef/data/vqa_dataset.py +115 -0
- multimodal/examples/common/data/__init__.py +7 -0
- multimodal/examples/common/data/multidata.py +194 -0
- multimodal/examples/flava/callbacks/__init__.py +7 -0
- multimodal/examples/flava/callbacks/multimodal_eval.py +108 -0
- multimodal/examples/flava/configs/finetuning/qnli.yaml +48 -0
- multimodal/examples/flava/configs/finetuning/rendered_sst2.yaml +37 -0
- multimodal/examples/flava/configs/pretraining/debug.yaml +61 -0
- multimodal/examples/flava/data/__init__.py +10 -0
- multimodal/examples/flava/data/datamodules.py +529 -0
- multimodal/examples/flava/data/imagenet_zeroshot_data.py +1095 -0
- multimodal/examples/flava/data/transforms.py +131 -0
- multimodal/examples/flava/data/utils.py +80 -0
- multimodal/examples/flava/native/README.md +43 -0
- multimodal/examples/flava/native/__init__.py +5 -0
- multimodal/examples/flava/native/configs/1.8b.yaml +79 -0
- multimodal/examples/flava/native/configs/10b.yaml +80 -0
- multimodal/examples/flava/native/configs/2.7b.yaml +79 -0
- multimodal/examples/flava/native/configs/4.8b.yaml +79 -0
- multimodal/examples/flava/native/configs/900m.yaml +79 -0
- multimodal/examples/flava/native/configs/pretrain_debug.yaml +63 -0
- multimodal/examples/flava/native/data.py +560 -0
- multimodal/examples/flava/native/model.py +78 -0
- multimodal/examples/flava/native/train.py +415 -0
- multimodal/examples/flava/native/utils.py +160 -0
- multimodal/examples/flava/notebooks/RemapFLAVACheckpoint.ipynb +172 -0
- multimodal/examples/flava/tools/convert_weights.py +72 -0
- multimodal/examples/mugen/data/README.md +10 -0
- multimodal/examples/mugen/data/coinrun/construct_from_json.py +756 -0
- multimodal/examples/mugen/data/coinrun/game.py +295 -0
- multimodal/examples/mugen/data/coinrun/generate_text_desc.py +435 -0
- multimodal/examples/mugen/data/mugen_datamodules.py +112 -0
- multimodal/examples/mugen/generation/LoadAndComparePretrainedVQVAE.ipynb +383 -0
- multimodal/examples/mugen/generation/README.md +33 -0
- multimodal/examples/mugen/generation/text_video_gpt.py +260 -0
- multimodal/examples/mugen/generation/video_vqvae.py +113 -0
- multimodal/examples/mugen/retrieval/README.md +34 -0
- multimodal/examples/mugen/retrieval/configs/eval.yaml +48 -0
- multimodal/examples/mugen/retrieval/configs/train.yaml +53 -0
- multimodal/examples/mugen/retrieval/definitions.py +105 -0
- multimodal/examples/mugen/retrieval/eval.py +54 -0
- multimodal/examples/mugen/retrieval/model.py +145 -0
- multimodal/examples/mugen/retrieval/train.py +67 -0
multimodal/examples/albef/configs/retrieval.yaml
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hidden_size: &hidden_size 768
|
| 2 |
+
vocab_size: &vocab_size 30522
|
| 3 |
+
type_vocab_size: &type_vocab_size 2
|
| 4 |
+
max_position_embeddings: &max_position_embeddings 512
|
| 5 |
+
pad_token_id: &pad_token_id 0
|
| 6 |
+
embed_size: &embed_size 256
|
| 7 |
+
|
| 8 |
+
seed: 42
|
| 9 |
+
world_size: 1
|
| 10 |
+
device: "cuda"
|
| 11 |
+
dist_url: "env://"
|
| 12 |
+
output_path: "./examples/albef/outputs/retrieval_output.pt"
|
| 13 |
+
|
| 14 |
+
datamodule_args:
|
| 15 |
+
train_files: ["./examples/albef/data_files/coco_train.json"]
|
| 16 |
+
test_files: ["./examples/albef/data_files/coco_test.json"]
|
| 17 |
+
image_root: "./examples/albef/data_files/coco"
|
| 18 |
+
batch_size: 32
|
| 19 |
+
num_workers: 8
|
| 20 |
+
|
| 21 |
+
vision_encoder_args:
|
| 22 |
+
hidden_size: *hidden_size
|
| 23 |
+
image_size: 384
|
| 24 |
+
patch_size: 16
|
| 25 |
+
num_hidden_layers: 12
|
| 26 |
+
num_attention_heads: 12
|
| 27 |
+
mlp_dim: 3072
|
| 28 |
+
dropout: 0.0
|
| 29 |
+
attention_dropout: 0.0
|
| 30 |
+
layer_norm_eps: 1e-6
|
| 31 |
+
|
| 32 |
+
text_encoder_args:
|
| 33 |
+
vocab_size: *vocab_size
|
| 34 |
+
hidden_size: *hidden_size
|
| 35 |
+
type_vocab_size: *type_vocab_size
|
| 36 |
+
max_position_embeddings: *max_position_embeddings
|
| 37 |
+
pad_token_id: *pad_token_id
|
| 38 |
+
num_hidden_layers: 6
|
| 39 |
+
num_attention_heads: 12
|
| 40 |
+
intermediate_size: 3072
|
| 41 |
+
layer_norm_eps: 1e-12
|
| 42 |
+
dropout: 0.0
|
| 43 |
+
|
| 44 |
+
multimodal_encoder_args:
|
| 45 |
+
hidden_size: *hidden_size
|
| 46 |
+
num_hidden_layers: 6
|
| 47 |
+
num_attention_heads: 12
|
| 48 |
+
intermediate_size: 3072
|
| 49 |
+
layer_norm_eps: 1e-12
|
| 50 |
+
|
| 51 |
+
projection_args:
|
| 52 |
+
in_features: *hidden_size
|
| 53 |
+
out_features: *embed_size
|
| 54 |
+
|
| 55 |
+
similarity_args:
|
| 56 |
+
embed_size: *embed_size
|
| 57 |
+
queue_size: 65536
|
| 58 |
+
temp: 0.07
|
| 59 |
+
|
| 60 |
+
training_args:
|
| 61 |
+
log_every_n_steps: 100
|
| 62 |
+
alpha: 0.4
|
| 63 |
+
weight_decay: 0.02
|
| 64 |
+
lr: 1e-5
|
| 65 |
+
min_lr: 1e-6
|
| 66 |
+
max_epochs: 5
|
| 67 |
+
step_size: 100
|
| 68 |
+
warmup_steps: 1
|
| 69 |
+
checkpoint_root: "./examples/albef/checkpoints"
|
| 70 |
+
|
| 71 |
+
eval_args:
|
| 72 |
+
log_every_n_steps: 100
|
| 73 |
+
k_test: 256
|
multimodal/examples/albef/configs/vqa.yaml
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hidden_size: &hidden_size 768
|
| 2 |
+
vocab_size: &vocab_size 30522
|
| 3 |
+
type_vocab_size: &type_vocab_size 2
|
| 4 |
+
max_position_embeddings: &max_position_embeddings 512
|
| 5 |
+
pad_token_id: &pad_token_id 0
|
| 6 |
+
|
| 7 |
+
seed: 42
|
| 8 |
+
world_size: 1
|
| 9 |
+
device: "cuda"
|
| 10 |
+
dist_url: "env://"
|
| 11 |
+
output_root: "./examples/albef/outputs"
|
| 12 |
+
|
| 13 |
+
datamodule_args:
|
| 14 |
+
train_files: ["./examples/albef/data_files/vqa_train.json", "./examples/albef/data_files/vg_qa.json", "./examples/albef/data_files/vqa_val.json"]
|
| 15 |
+
test_files: ["./examples/albef/data_files/vqa_test.json"]
|
| 16 |
+
answer_list: "./examples/albef/data_files/answer_list.json"
|
| 17 |
+
vqa_root: "./examples/albef/data_files/coco"
|
| 18 |
+
vg_root: "./examples/albef/data_files/visual_genome"
|
| 19 |
+
batch_size: 32
|
| 20 |
+
num_workers: 8
|
| 21 |
+
|
| 22 |
+
vision_encoder_args:
|
| 23 |
+
hidden_size: *hidden_size
|
| 24 |
+
image_size: 384
|
| 25 |
+
patch_size: 16
|
| 26 |
+
num_hidden_layers: 12
|
| 27 |
+
num_attention_heads: 12
|
| 28 |
+
mlp_dim: 3072
|
| 29 |
+
dropout: 0.0
|
| 30 |
+
attention_dropout: 0.0
|
| 31 |
+
layer_norm_eps: 1e-6
|
| 32 |
+
|
| 33 |
+
text_encoder_args:
|
| 34 |
+
vocab_size: *vocab_size
|
| 35 |
+
hidden_size: *hidden_size
|
| 36 |
+
type_vocab_size: *type_vocab_size
|
| 37 |
+
max_position_embeddings: *max_position_embeddings
|
| 38 |
+
pad_token_id: *pad_token_id
|
| 39 |
+
num_hidden_layers: 6
|
| 40 |
+
num_attention_heads: 12
|
| 41 |
+
intermediate_size: 3072
|
| 42 |
+
layer_norm_eps: 1e-12
|
| 43 |
+
dropout: 0.0
|
| 44 |
+
|
| 45 |
+
multimodal_encoder_args:
|
| 46 |
+
hidden_size: *hidden_size
|
| 47 |
+
num_hidden_layers: 6
|
| 48 |
+
num_attention_heads: 12
|
| 49 |
+
intermediate_size: 3072
|
| 50 |
+
layer_norm_eps: 1e-12
|
| 51 |
+
|
| 52 |
+
text_embeddings_args:
|
| 53 |
+
hidden_size: *hidden_size
|
| 54 |
+
vocab_size: *vocab_size
|
| 55 |
+
pad_token_id: *pad_token_id
|
| 56 |
+
max_position_embeddings: *max_position_embeddings
|
| 57 |
+
type_vocab_size: *type_vocab_size
|
| 58 |
+
layer_norm_eps: 1e-12
|
| 59 |
+
|
| 60 |
+
prediction_head_args:
|
| 61 |
+
hidden_size: *hidden_size
|
| 62 |
+
vocab_size: *vocab_size
|
| 63 |
+
layer_norm_eps: 1e-12
|
| 64 |
+
|
| 65 |
+
training_args:
|
| 66 |
+
log_every_n_steps: 100
|
| 67 |
+
alpha: 0.4
|
| 68 |
+
weight_decay: 0.02
|
| 69 |
+
lr: 2e-5
|
| 70 |
+
min_lr: 1e-6
|
| 71 |
+
max_epochs: 8
|
| 72 |
+
step_size: 100
|
| 73 |
+
warmup_steps: 4
|
| 74 |
+
checkpoint_root: "./examples/albef/checkpoints"
|
| 75 |
+
|
| 76 |
+
eval_args:
|
| 77 |
+
log_every_n_steps: 100
|
| 78 |
+
k_test: 128
|
multimodal/examples/albef/data/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
multimodal/examples/albef/data/retrieval_datamodule.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from data.retrieval_dataset import (
|
| 11 |
+
ImageToTextRetrievalDataset,
|
| 12 |
+
RetrievalTrainingDataset,
|
| 13 |
+
TextToImageRetrievalDataset,
|
| 14 |
+
)
|
| 15 |
+
from data.transforms import (
|
| 16 |
+
ALBEFTextTransform,
|
| 17 |
+
testing_image_transform,
|
| 18 |
+
training_image_transform,
|
| 19 |
+
)
|
| 20 |
+
from pytorch_lightning import LightningDataModule
|
| 21 |
+
from torch import Tensor
|
| 22 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 23 |
+
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class RetrievalDataModule(LightningDataModule):
|
| 27 |
+
"""
|
| 28 |
+
The Data Module for Retrieval task.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
train_files (List[str]): The paths to training json files.
|
| 32 |
+
test_files (List[str]): The paths to testing json files.
|
| 33 |
+
image_root (str): The path to image data directory.
|
| 34 |
+
batch_size (int): The sampling batch size.
|
| 35 |
+
num_workers (int): The number of workers for the distributed mode.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
train_files: List[str],
|
| 41 |
+
test_files: List[str],
|
| 42 |
+
image_root: str,
|
| 43 |
+
batch_size: int,
|
| 44 |
+
num_workers: int,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.train_dataset = RetrievalTrainingDataset(
|
| 48 |
+
train_files,
|
| 49 |
+
image_root,
|
| 50 |
+
training_image_transform(),
|
| 51 |
+
ALBEFTextTransform(truncate=True, max_seq_len=30, add_end_token=False),
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
self.image_dataset = ImageToTextRetrievalDataset(
|
| 55 |
+
test_files,
|
| 56 |
+
image_root,
|
| 57 |
+
testing_image_transform(),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
self.text_dataset = TextToImageRetrievalDataset(
|
| 61 |
+
test_files,
|
| 62 |
+
ALBEFTextTransform(
|
| 63 |
+
truncate=True,
|
| 64 |
+
pad_to_max_seq_len=True,
|
| 65 |
+
max_seq_len=30,
|
| 66 |
+
add_end_token=False,
|
| 67 |
+
),
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
self.batch_size = batch_size
|
| 71 |
+
self.num_workers = num_workers
|
| 72 |
+
|
| 73 |
+
def _get_sampler(
|
| 74 |
+
self,
|
| 75 |
+
dataset: Dataset,
|
| 76 |
+
shuffle: bool,
|
| 77 |
+
is_distributed: bool,
|
| 78 |
+
num_tasks: int,
|
| 79 |
+
global_rank: int,
|
| 80 |
+
) -> Optional[DistributedSampler]:
|
| 81 |
+
# do not return a sampler if is not in distributed mode
|
| 82 |
+
# a default RandomSampler is used in this case
|
| 83 |
+
if not is_distributed:
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
return DistributedSampler(
|
| 87 |
+
dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
def train_dataloader(
|
| 91 |
+
self,
|
| 92 |
+
is_distributed: bool = False,
|
| 93 |
+
num_tasks: int = 0,
|
| 94 |
+
global_rank: int = 0,
|
| 95 |
+
drop_last: bool = True,
|
| 96 |
+
) -> DataLoader:
|
| 97 |
+
"""
|
| 98 |
+
DataLoader Outputs:
|
| 99 |
+
images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
|
| 100 |
+
text (Tensor): Tensor of shape (B, L) of text inputs.
|
| 101 |
+
text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
|
| 102 |
+
idx (Tensor): Tensor of shape (B) of image identifiers.
|
| 103 |
+
"""
|
| 104 |
+
sampler = self._get_sampler(
|
| 105 |
+
dataset=self.train_dataset,
|
| 106 |
+
shuffle=True,
|
| 107 |
+
is_distributed=is_distributed,
|
| 108 |
+
num_tasks=num_tasks,
|
| 109 |
+
global_rank=global_rank,
|
| 110 |
+
)
|
| 111 |
+
shuffle = sampler is None
|
| 112 |
+
return DataLoader(
|
| 113 |
+
self.train_dataset,
|
| 114 |
+
batch_size=self.batch_size,
|
| 115 |
+
num_workers=self.num_workers,
|
| 116 |
+
pin_memory=True,
|
| 117 |
+
sampler=sampler,
|
| 118 |
+
shuffle=shuffle,
|
| 119 |
+
collate_fn=retrieval_train_collate_fn,
|
| 120 |
+
drop_last=drop_last,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
def image_dataloader(
|
| 124 |
+
self,
|
| 125 |
+
drop_last: bool = False,
|
| 126 |
+
) -> DataLoader:
|
| 127 |
+
"""
|
| 128 |
+
DataLoader Outputs:
|
| 129 |
+
images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
|
| 130 |
+
"""
|
| 131 |
+
return DataLoader(
|
| 132 |
+
self.image_dataset,
|
| 133 |
+
batch_size=self.batch_size,
|
| 134 |
+
num_workers=self.num_workers,
|
| 135 |
+
pin_memory=True,
|
| 136 |
+
sampler=None,
|
| 137 |
+
shuffle=False,
|
| 138 |
+
collate_fn=None,
|
| 139 |
+
drop_last=drop_last,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def text_dataloader(
|
| 143 |
+
self,
|
| 144 |
+
drop_last: bool = False,
|
| 145 |
+
) -> DataLoader:
|
| 146 |
+
"""
|
| 147 |
+
DataLoader Outputs:
|
| 148 |
+
text (Tensor): Tensor of shape (B, L) of text inputs.
|
| 149 |
+
text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
|
| 150 |
+
"""
|
| 151 |
+
return DataLoader(
|
| 152 |
+
self.text_dataset,
|
| 153 |
+
batch_size=self.batch_size,
|
| 154 |
+
num_workers=self.num_workers,
|
| 155 |
+
pin_memory=True,
|
| 156 |
+
sampler=None,
|
| 157 |
+
shuffle=False,
|
| 158 |
+
collate_fn=text_collate_fn,
|
| 159 |
+
drop_last=drop_last,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def retrieval_train_collate_fn(
|
| 164 |
+
batch: List[Tuple[Tensor, Tensor, int]],
|
| 165 |
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
| 166 |
+
image_list = []
|
| 167 |
+
text_list = []
|
| 168 |
+
idx_list = []
|
| 169 |
+
for image, text, idx in batch:
|
| 170 |
+
image_list.append(image)
|
| 171 |
+
text_list.append(text)
|
| 172 |
+
idx_list.append(idx)
|
| 173 |
+
images = torch.stack(image_list, dim=0)
|
| 174 |
+
text = pad_sequence(text_list, batch_first=True)
|
| 175 |
+
text_atts = (text != 0).type(torch.long)
|
| 176 |
+
idx = Tensor(idx_list).type(torch.long)
|
| 177 |
+
return (
|
| 178 |
+
images,
|
| 179 |
+
text,
|
| 180 |
+
text_atts,
|
| 181 |
+
idx,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def text_collate_fn(batch: List[Tensor]) -> Tuple[Tensor, Tensor]:
|
| 186 |
+
text = pad_sequence(batch, batch_first=True)
|
| 187 |
+
text_atts = (text != 0).type(torch.long)
|
| 188 |
+
return text, text_atts
|
multimodal/examples/albef/data/retrieval_dataset.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from typing import Callable, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from torch.utils.data import Dataset
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RetrievalTrainingDataset(Dataset):
|
| 17 |
+
"""
|
| 18 |
+
Create the training dataset for Retrieval task.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
ann_file (List[str]): The paths to training annotation json files.
|
| 22 |
+
image_root (str): The path to image data directory.
|
| 23 |
+
image_transform (Callable[[Image.Image], Tensor]): Image data transform.
|
| 24 |
+
text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.
|
| 25 |
+
|
| 26 |
+
Dataset Outputs:
|
| 27 |
+
image (Tensor): Transformed image input tensor of shape (C, H, W).
|
| 28 |
+
caption (Tensor): Transformed text token input ids.
|
| 29 |
+
idx (int): The unique identifier for the image.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
ann_file: List[str],
|
| 35 |
+
image_root: str,
|
| 36 |
+
image_transform: Callable[[Image.Image], Tensor],
|
| 37 |
+
text_transform: Callable[[Union[List[str], str]], Tensor],
|
| 38 |
+
) -> None:
|
| 39 |
+
self.ann = []
|
| 40 |
+
for f in ann_file:
|
| 41 |
+
self.ann += json.load(open(f, "r"))
|
| 42 |
+
|
| 43 |
+
self.image_root = image_root
|
| 44 |
+
self.image_transform = image_transform
|
| 45 |
+
self.text_transform = text_transform
|
| 46 |
+
|
| 47 |
+
self.idx = {} # map str image_id from dataset to int ids
|
| 48 |
+
i = 0
|
| 49 |
+
for ann in self.ann:
|
| 50 |
+
image_id = ann["image_id"]
|
| 51 |
+
if image_id not in self.idx.keys():
|
| 52 |
+
self.idx[image_id] = i
|
| 53 |
+
i += 1
|
| 54 |
+
|
| 55 |
+
def __len__(self) -> int:
|
| 56 |
+
return len(self.ann)
|
| 57 |
+
|
| 58 |
+
def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, int]:
|
| 59 |
+
ann = self.ann[index]
|
| 60 |
+
image_path = os.path.join(self.image_root, ann["image"])
|
| 61 |
+
image = Image.open(image_path).convert("RGB")
|
| 62 |
+
image = self.image_transform(image)
|
| 63 |
+
caption = self.text_transform(ann["caption"])
|
| 64 |
+
return image, caption, self.idx[ann["image_id"]]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ImageToTextRetrievalDataset(Dataset):
|
| 68 |
+
"""
|
| 69 |
+
Create the dataset for Image-to-Text Retrieval task.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
ann_file (List[str]): The paths to annotation json files.
|
| 73 |
+
image_root (str): The path to image data directory.
|
| 74 |
+
image_transform (Callable[[Image.Image], Tensor]): Image data transform.
|
| 75 |
+
|
| 76 |
+
Dataset Outputs:
|
| 77 |
+
image (Tensor): Transformed image input tensor of shape (C, H, W).
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
ann_file: List[str],
|
| 83 |
+
image_root: str,
|
| 84 |
+
image_transform: Callable[[Image.Image], Tensor],
|
| 85 |
+
) -> None:
|
| 86 |
+
self.image_root = image_root
|
| 87 |
+
self.image_transform = image_transform
|
| 88 |
+
|
| 89 |
+
self.ann = []
|
| 90 |
+
self.images = [] # paths to all images in the dataset
|
| 91 |
+
self.image_to_text = {} # map image ids to text ids for evaluation
|
| 92 |
+
for f in ann_file:
|
| 93 |
+
self.ann += json.load(open(f, "r"))
|
| 94 |
+
|
| 95 |
+
text_id = 0
|
| 96 |
+
for image_id, ann in enumerate(self.ann):
|
| 97 |
+
self.images.append(ann["image"])
|
| 98 |
+
num_text = len(ann["caption"])
|
| 99 |
+
self.image_to_text[image_id] = list(range(text_id, text_id + num_text))
|
| 100 |
+
text_id += num_text
|
| 101 |
+
|
| 102 |
+
def __len__(self) -> int:
|
| 103 |
+
return len(self.images)
|
| 104 |
+
|
| 105 |
+
def __getitem__(self, index: int) -> Tensor:
|
| 106 |
+
image_path = os.path.join(self.image_root, self.images[index])
|
| 107 |
+
image = Image.open(image_path).convert("RGB")
|
| 108 |
+
image = self.image_transform(image)
|
| 109 |
+
return image
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class TextToImageRetrievalDataset(Dataset):
|
| 113 |
+
"""
|
| 114 |
+
Create the dataset for Text-to-Image Retrieval task.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
ann_file (List[str]): The paths to annotation json files.
|
| 118 |
+
text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.
|
| 119 |
+
|
| 120 |
+
Dataset Outputs:
|
| 121 |
+
text (Tensor): Transformed text token input ids.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
ann_file: List[str],
|
| 127 |
+
text_transform: Callable[[Union[List[str], str]], Tensor],
|
| 128 |
+
) -> None:
|
| 129 |
+
self.text_transform = text_transform
|
| 130 |
+
|
| 131 |
+
self.ann = []
|
| 132 |
+
self.text = [] # all text strings in the dataset
|
| 133 |
+
self.text_to_image = {} # map text ids to image ids for evaluation
|
| 134 |
+
for f in ann_file:
|
| 135 |
+
self.ann += json.load(open(f, "r"))
|
| 136 |
+
|
| 137 |
+
text_id = 0
|
| 138 |
+
for image_id, ann in enumerate(self.ann):
|
| 139 |
+
for caption in ann["caption"]:
|
| 140 |
+
self.text.append(caption)
|
| 141 |
+
self.text_to_image[text_id] = image_id
|
| 142 |
+
text_id += 1
|
| 143 |
+
|
| 144 |
+
def __len__(self) -> int:
|
| 145 |
+
return len(self.text)
|
| 146 |
+
|
| 147 |
+
def __getitem__(self, index: int) -> Tensor:
|
| 148 |
+
text = self.text_transform(self.text[index])
|
| 149 |
+
return text
|
multimodal/examples/albef/data/transforms.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
from typing import List, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from torchtext.transforms import PadTransform, Sequential, ToTensor, Truncate
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
from transformers.models.bert.tokenization_bert import BertTokenizer
|
| 15 |
+
|
| 16 |
+
# mean and standard deviation from the ALBEF repo:
|
| 17 |
+
# https://github.com/salesforce/ALBEF/blob/main/dataset/__init__.py#L16
|
| 18 |
+
MEAN = (0.48145466, 0.4578275, 0.40821073)
|
| 19 |
+
STD_DEV = (0.26862954, 0.26130258, 0.27577711)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ALBEFTextTransform:
|
| 23 |
+
"""
|
| 24 |
+
Remove punctuations and trailing spaces in input text and transform it into
|
| 25 |
+
a Tensor of token ids using BERTTokenizer.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
pretrained_tokenizer (str): Pretrained tokenizer to use.
|
| 29 |
+
Default: "bert-base-uncased"
|
| 30 |
+
do_pre_process (bool): Whether to pre-process input text.
|
| 31 |
+
Defaults to True.
|
| 32 |
+
truncate (bool): Whether to truncate input text to max_seq_length.
|
| 33 |
+
Defaults to False.
|
| 34 |
+
pad_to_max_seq_len (bool): Whether to pad the sequence to max_seq_length.
|
| 35 |
+
add_end_token (bool): Whether to add the end-of-sentence token.
|
| 36 |
+
Defaults to True.
|
| 37 |
+
max_seq_len (int): The max sequence length after truncating or padding.
|
| 38 |
+
Defaults to 25.
|
| 39 |
+
cls_token_id (int): Value to represent the start of each text.
|
| 40 |
+
Defaults to 101, Hugging Face's BERT cls token id.
|
| 41 |
+
sep_token_id (int): Value to represent the end of each text.
|
| 42 |
+
Defaults to 102, Hugging Face's BERT sep token id.
|
| 43 |
+
pad_token_id (int): Value with which to pad each text so that all texts are the same length.
|
| 44 |
+
Defaults to 0, Hugging Face's BERT pad token id.
|
| 45 |
+
|
| 46 |
+
Inputs:
|
| 47 |
+
text (Union[List[str], str]): Input text to transform.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
pretrained_tokenizer: str = "bert-base-uncased",
|
| 53 |
+
do_pre_process: bool = True,
|
| 54 |
+
truncate: bool = False,
|
| 55 |
+
pad_to_max_seq_len: bool = False,
|
| 56 |
+
add_end_token: bool = True,
|
| 57 |
+
max_seq_len: int = 25,
|
| 58 |
+
cls_token_id: int = 101,
|
| 59 |
+
sep_token_id: int = 102,
|
| 60 |
+
pad_token_id: int = 0,
|
| 61 |
+
):
|
| 62 |
+
self.do_pre_process = do_pre_process
|
| 63 |
+
self.cls_token_id = cls_token_id
|
| 64 |
+
self.sep_token_id = sep_token_id
|
| 65 |
+
self.pad_token_id = pad_token_id
|
| 66 |
+
self.add_end_token = add_end_token
|
| 67 |
+
|
| 68 |
+
self.tokenizer = BertTokenizer.from_pretrained(pretrained_tokenizer)
|
| 69 |
+
self.transform = Sequential(
|
| 70 |
+
Truncate(max_seq_len=max_seq_len) if truncate else torch.nn.Identity(),
|
| 71 |
+
ToTensor(padding_value=self.pad_token_id),
|
| 72 |
+
(
|
| 73 |
+
PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id)
|
| 74 |
+
if pad_to_max_seq_len
|
| 75 |
+
else torch.nn.Identity()
|
| 76 |
+
),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def pre_process(self, text: str) -> str:
|
| 80 |
+
text = (
|
| 81 |
+
re.sub(
|
| 82 |
+
r"([,.'!?\"()*#:;~])",
|
| 83 |
+
"",
|
| 84 |
+
text,
|
| 85 |
+
)
|
| 86 |
+
.replace("-", " ")
|
| 87 |
+
.replace("/", " ")
|
| 88 |
+
)
|
| 89 |
+
text = text.rstrip(" ")
|
| 90 |
+
|
| 91 |
+
return text
|
| 92 |
+
|
| 93 |
+
def __call__(self, text: Union[List[str], str]) -> torch.Tensor:
|
| 94 |
+
if self.do_pre_process:
|
| 95 |
+
if isinstance(text, str):
|
| 96 |
+
text = self.pre_process(text)
|
| 97 |
+
else:
|
| 98 |
+
text = [self.pre_process(t) for t in text]
|
| 99 |
+
tokens = self.tokenizer(text)["input_ids"]
|
| 100 |
+
if not self.add_end_token and tokens[-1] == self.sep_token_id:
|
| 101 |
+
tokens = tokens[:-1]
|
| 102 |
+
input_ids = self.transform(tokens)
|
| 103 |
+
|
| 104 |
+
return input_ids
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def training_image_transform(
|
| 108 |
+
image_size: int = 384,
|
| 109 |
+
scale: Tuple[float, float] = (0.5, 1.0),
|
| 110 |
+
image_interpolation=transforms.InterpolationMode.BICUBIC,
|
| 111 |
+
mean: Tuple[float, float, float] = MEAN,
|
| 112 |
+
std_dev: Tuple[float, float, float] = STD_DEV,
|
| 113 |
+
) -> transforms.Compose:
|
| 114 |
+
return transforms.Compose(
|
| 115 |
+
[
|
| 116 |
+
transforms.RandomResizedCrop(
|
| 117 |
+
image_size, scale=scale, interpolation=image_interpolation
|
| 118 |
+
),
|
| 119 |
+
transforms.RandomHorizontalFlip(),
|
| 120 |
+
transforms.RandAugment(2, 7),
|
| 121 |
+
transforms.ToTensor(),
|
| 122 |
+
transforms.Normalize(mean, std_dev),
|
| 123 |
+
]
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def testing_image_transform(
|
| 128 |
+
image_size: int = 384,
|
| 129 |
+
image_interpolation=transforms.InterpolationMode.BICUBIC,
|
| 130 |
+
mean: Tuple[float, float, float] = MEAN,
|
| 131 |
+
std_dev: Tuple[float, float, float] = STD_DEV,
|
| 132 |
+
) -> transforms.Compose:
|
| 133 |
+
return transforms.Compose(
|
| 134 |
+
[
|
| 135 |
+
transforms.Resize(
|
| 136 |
+
(image_size, image_size), interpolation=image_interpolation
|
| 137 |
+
),
|
| 138 |
+
transforms.ToTensor(),
|
| 139 |
+
transforms.Normalize(mean, std_dev),
|
| 140 |
+
]
|
| 141 |
+
)
|
multimodal/examples/albef/data/vqa_datamodules.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from data.transforms import (
|
| 11 |
+
ALBEFTextTransform,
|
| 12 |
+
testing_image_transform,
|
| 13 |
+
training_image_transform,
|
| 14 |
+
)
|
| 15 |
+
from data.vqa_dataset import VQADataset
|
| 16 |
+
from pytorch_lightning import LightningDataModule
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 19 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class VQADataModule(LightningDataModule):
|
| 23 |
+
"""
|
| 24 |
+
The Data Module for Visual Question Answering task.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
train_files (List[str]): The paths to training json files.
|
| 28 |
+
test_files (List[str]): The paths to testing json files.
|
| 29 |
+
answer_list (str): The path to the answers list.
|
| 30 |
+
vqa_root (str): The path to vqa data directory.
|
| 31 |
+
vg_root (str): The path to vg data directory.
|
| 32 |
+
batch_size (int): The sampling batch size.
|
| 33 |
+
num_workers (int): The number of workers for the distributed mode.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
train_files: List[str],
|
| 39 |
+
test_files: List[str],
|
| 40 |
+
answer_list: str,
|
| 41 |
+
vqa_root: str,
|
| 42 |
+
vg_root: str,
|
| 43 |
+
batch_size: int,
|
| 44 |
+
num_workers: int,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.train_dataset = VQADataset(
|
| 48 |
+
train_files,
|
| 49 |
+
vqa_root,
|
| 50 |
+
vg_root,
|
| 51 |
+
image_transform=training_image_transform(),
|
| 52 |
+
question_transform=ALBEFTextTransform(
|
| 53 |
+
truncate=True, max_seq_len=25, add_end_token=False
|
| 54 |
+
),
|
| 55 |
+
answer_transform=ALBEFTextTransform(do_pre_process=False),
|
| 56 |
+
split="train",
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
self.test_dataset = VQADataset(
|
| 60 |
+
test_files,
|
| 61 |
+
vqa_root,
|
| 62 |
+
vg_root,
|
| 63 |
+
image_transform=testing_image_transform(),
|
| 64 |
+
question_transform=ALBEFTextTransform(add_end_token=False),
|
| 65 |
+
answer_transform=ALBEFTextTransform(do_pre_process=False),
|
| 66 |
+
split="test",
|
| 67 |
+
answer_list=answer_list,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
self.batch_size = batch_size
|
| 71 |
+
self.num_workers = num_workers
|
| 72 |
+
|
| 73 |
+
def _get_sampler(
|
| 74 |
+
self,
|
| 75 |
+
dataset: VQADataset,
|
| 76 |
+
shuffle: bool,
|
| 77 |
+
is_distributed: bool,
|
| 78 |
+
num_tasks: int,
|
| 79 |
+
global_rank: int,
|
| 80 |
+
) -> Optional[DistributedSampler]:
|
| 81 |
+
if not is_distributed:
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
return DistributedSampler(
|
| 85 |
+
dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def train_dataloader(
|
| 89 |
+
self,
|
| 90 |
+
is_distributed: bool = False,
|
| 91 |
+
num_tasks: int = 0,
|
| 92 |
+
global_rank: int = 0,
|
| 93 |
+
drop_last: bool = True,
|
| 94 |
+
) -> DataLoader:
|
| 95 |
+
"""
|
| 96 |
+
DataLoader Outputs:
|
| 97 |
+
images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
|
| 98 |
+
questions (Tensor): Tensor of shape (B, L) of question inputs.
|
| 99 |
+
question_atts (Tensor): Tensor of shape (B, L) of question attention mask.
|
| 100 |
+
answers (Tensor): Tensor of shape (N, M) of answer inputs.
|
| 101 |
+
N >= B because a vqa sample can have multiple answers.
|
| 102 |
+
answer_atts (Tensor): Tensor of shape (N, M) of answer attention mask.
|
| 103 |
+
weights (Tensor): Tensor of shape (N) of answer weights.
|
| 104 |
+
ans_lengths (List[int]): List of length B and sum N where
|
| 105 |
+
ans_lengths[i] = number of answers for images[i] and questions[i].
|
| 106 |
+
"""
|
| 107 |
+
sampler = self._get_sampler(
|
| 108 |
+
dataset=self.train_dataset,
|
| 109 |
+
shuffle=True,
|
| 110 |
+
is_distributed=is_distributed,
|
| 111 |
+
num_tasks=num_tasks,
|
| 112 |
+
global_rank=global_rank,
|
| 113 |
+
)
|
| 114 |
+
shuffle = sampler is None
|
| 115 |
+
return DataLoader(
|
| 116 |
+
self.train_dataset,
|
| 117 |
+
batch_size=self.batch_size,
|
| 118 |
+
num_workers=self.num_workers,
|
| 119 |
+
pin_memory=True,
|
| 120 |
+
sampler=sampler,
|
| 121 |
+
shuffle=shuffle,
|
| 122 |
+
collate_fn=vqa_train_collate_fn,
|
| 123 |
+
drop_last=drop_last,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def test_dataloader(
|
| 127 |
+
self,
|
| 128 |
+
is_distributed: bool = False,
|
| 129 |
+
num_tasks: int = 0,
|
| 130 |
+
global_rank: int = 0,
|
| 131 |
+
drop_last=False,
|
| 132 |
+
) -> DataLoader:
|
| 133 |
+
"""
|
| 134 |
+
DataLoader Outputs:
|
| 135 |
+
images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
|
| 136 |
+
questions (Tensor): Tensor of shape (B, L) of question inputs.
|
| 137 |
+
question_atts (Tensor): Tensor of shape (B, L) of question attention mask.
|
| 138 |
+
question_ids (List): List of length B of question ids.
|
| 139 |
+
"""
|
| 140 |
+
sampler = self._get_sampler(
|
| 141 |
+
dataset=self.test_dataset,
|
| 142 |
+
shuffle=False,
|
| 143 |
+
is_distributed=is_distributed,
|
| 144 |
+
num_tasks=num_tasks,
|
| 145 |
+
global_rank=global_rank,
|
| 146 |
+
)
|
| 147 |
+
return DataLoader(
|
| 148 |
+
self.test_dataset,
|
| 149 |
+
batch_size=self.batch_size,
|
| 150 |
+
num_workers=self.num_workers,
|
| 151 |
+
pin_memory=True,
|
| 152 |
+
sampler=sampler,
|
| 153 |
+
shuffle=False,
|
| 154 |
+
collate_fn=vqa_test_collate_fn,
|
| 155 |
+
drop_last=drop_last,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def vqa_train_collate_fn(
|
| 160 |
+
batch: List[Tuple[Tensor, Tensor, List[Tensor], List[float]]],
|
| 161 |
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[int]]:
|
| 162 |
+
image_list = []
|
| 163 |
+
question_list = []
|
| 164 |
+
answer_list = []
|
| 165 |
+
weight_list = []
|
| 166 |
+
ans_lengths = []
|
| 167 |
+
for image, question, answer, weights in batch:
|
| 168 |
+
image_list.append(image)
|
| 169 |
+
question_list.append(question)
|
| 170 |
+
answer_list += answer
|
| 171 |
+
weight_list += weights
|
| 172 |
+
ans_lengths.append(len(answer))
|
| 173 |
+
images = torch.stack(image_list, dim=0)
|
| 174 |
+
questions = pad_sequence(question_list, batch_first=True)
|
| 175 |
+
question_atts = (questions != 0).type(torch.long)
|
| 176 |
+
answers = pad_sequence(answer_list, batch_first=True)
|
| 177 |
+
answer_atts = (answers != 0).type(torch.long)
|
| 178 |
+
weights = torch.Tensor(weight_list)
|
| 179 |
+
return (
|
| 180 |
+
images,
|
| 181 |
+
questions,
|
| 182 |
+
question_atts,
|
| 183 |
+
answers,
|
| 184 |
+
answer_atts,
|
| 185 |
+
weights,
|
| 186 |
+
ans_lengths,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def vqa_test_collate_fn(
|
| 191 |
+
batch: List[Tuple[Tensor, Tensor, int]],
|
| 192 |
+
) -> Tuple[Tensor, Tensor, Tensor, List[int]]:
|
| 193 |
+
image_list, question_list, question_ids = [], [], []
|
| 194 |
+
for image, question, question_id in batch:
|
| 195 |
+
image_list.append(image)
|
| 196 |
+
question_list.append(question)
|
| 197 |
+
question_ids.append(question_id)
|
| 198 |
+
images = torch.stack(image_list, dim=0)
|
| 199 |
+
questions = pad_sequence(question_list, batch_first=True)
|
| 200 |
+
question_atts = (questions != 0).type(torch.long)
|
| 201 |
+
return (
|
| 202 |
+
images,
|
| 203 |
+
questions,
|
| 204 |
+
question_atts,
|
| 205 |
+
question_ids,
|
| 206 |
+
)
|
multimodal/examples/albef/data/vqa_dataset.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from typing import Callable, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch.utils.data import Dataset
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class VQADataset(Dataset):
|
| 19 |
+
"""
|
| 20 |
+
Create the dataset for VQA task.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
ann_file (List[str]): The paths to annotation json files.
|
| 24 |
+
vqa_root (str): The path to vqa data directory.
|
| 25 |
+
vg_root (str): The path to vg data directory.
|
| 26 |
+
image_transform (Callable[[Image.Image], Tensor]): image data transform.
|
| 27 |
+
question_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for questions.
|
| 28 |
+
answer_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for answers.
|
| 29 |
+
split (str): Indicates train or test. Default is train.
|
| 30 |
+
answer_list (str): The path to the answers list. Required for test split.
|
| 31 |
+
|
| 32 |
+
Dataset Outputs:
|
| 33 |
+
if split is train:
|
| 34 |
+
image (Tensor): Transformed image input tensor of shape (C, W, H).
|
| 35 |
+
question (Tensor): Transformed question token input ids.
|
| 36 |
+
answers (List[Tensor]): List of transformed answers token input ids.
|
| 37 |
+
answer_weights (List[float]): List of answer weights.
|
| 38 |
+
answer_weights[i] is proportional to the number of occurences of answers[i]
|
| 39 |
+
if split is test:
|
| 40 |
+
image (Tensor): Transformed image input tensor of shape (C, W, H).
|
| 41 |
+
question (Tensor): Transformed text token input ids.
|
| 42 |
+
question_id (int): The question sample id.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
ann_file: List[str],
|
| 48 |
+
vqa_root: str,
|
| 49 |
+
vg_root: str,
|
| 50 |
+
image_transform: Callable[[Image.Image], Tensor],
|
| 51 |
+
question_transform: Callable[[Union[List[str], str]], Tensor],
|
| 52 |
+
answer_transform: Callable[[Union[List[str], str]], Tensor],
|
| 53 |
+
split: str = "train",
|
| 54 |
+
answer_list: str = None,
|
| 55 |
+
) -> None:
|
| 56 |
+
self.ann = []
|
| 57 |
+
for f in ann_file:
|
| 58 |
+
self.ann += json.load(open(f, "r"))
|
| 59 |
+
|
| 60 |
+
self.vqa_root = vqa_root
|
| 61 |
+
self.vg_root = vg_root
|
| 62 |
+
self.image_transform = image_transform
|
| 63 |
+
self.question_transform = question_transform
|
| 64 |
+
self.answer_transform = answer_transform
|
| 65 |
+
self.split = split
|
| 66 |
+
|
| 67 |
+
if split == "test":
|
| 68 |
+
self.answer_list = json.load(open(answer_list, "r"))
|
| 69 |
+
self.answer_input_ids = self.answer_transform(self.answer_list)
|
| 70 |
+
self.answer_attention_mask = (self.answer_input_ids != 0).type(torch.long)
|
| 71 |
+
|
| 72 |
+
def __len__(self) -> int:
|
| 73 |
+
return len(self.ann)
|
| 74 |
+
|
| 75 |
+
def __getitem__(
|
| 76 |
+
self, index: int
|
| 77 |
+
) -> Union[
|
| 78 |
+
Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, List[Tensor], List[float]]
|
| 79 |
+
]:
|
| 80 |
+
ann = self.ann[index]
|
| 81 |
+
|
| 82 |
+
image_root = self.vqa_root if ann["dataset"] == "vqa" else self.vg_root
|
| 83 |
+
image_path = os.path.join(image_root, ann["image"])
|
| 84 |
+
image = Image.open(image_path).convert("RGB")
|
| 85 |
+
image = self.image_transform(image)
|
| 86 |
+
question = self.question_transform(ann["question"])
|
| 87 |
+
|
| 88 |
+
if self.split == "test":
|
| 89 |
+
return image, question, ann["question_id"]
|
| 90 |
+
|
| 91 |
+
elif self.split == "train":
|
| 92 |
+
if ann["dataset"] == "vqa":
|
| 93 |
+
# Each VQA sample question has a list of answers (with potential repeats)
|
| 94 |
+
# answer_weight[answer] = count(answer) / len(answers for the question)
|
| 95 |
+
answer_weights = {}
|
| 96 |
+
for answer in ann["answer"]:
|
| 97 |
+
if answer in answer_weights.keys():
|
| 98 |
+
answer_weights[answer] += 1 / len(ann["answer"])
|
| 99 |
+
else:
|
| 100 |
+
answer_weights[answer] = 1 / len(ann["answer"])
|
| 101 |
+
|
| 102 |
+
answers = list(answer_weights.keys())
|
| 103 |
+
answer_weights = list(answer_weights.values())
|
| 104 |
+
|
| 105 |
+
elif ann["dataset"] == "vg":
|
| 106 |
+
# A VG sample question has one answer so assign it a constant weight (0.5)
|
| 107 |
+
answers = [ann["answer"]]
|
| 108 |
+
answer_weights = [0.5]
|
| 109 |
+
|
| 110 |
+
answers = list(self.answer_transform(answers))
|
| 111 |
+
|
| 112 |
+
return image, question, answers, answer_weights
|
| 113 |
+
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError("dataset split should be train or test")
|
multimodal/examples/common/data/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .multidata import * # noqa F401
|
multimodal/examples/common/data/multidata.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import random
|
| 8 |
+
import warnings
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import Callable, List, Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from pytorch_lightning import LightningDataModule
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MultiDataLoader:
|
| 17 |
+
# NOTE: Please check MMF's MultiDataLoader if you want to support
|
| 18 |
+
# epoch based sampling funcs.
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
loaders: List[torch.utils.data.DataLoader],
|
| 22 |
+
sampling_func: Optional[Callable] = None,
|
| 23 |
+
):
|
| 24 |
+
"""MultiDataLoader takes in a list of dataloaders and a sampling function
|
| 25 |
+
and cycles between these dataloaders after each batch based on the index
|
| 26 |
+
provided by the sampling function passed. Useful for doing multi-tasking
|
| 27 |
+
over multiple datasets
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
loaders (List[torch.utils.data.DataLoader]): List of dataloaders on
|
| 31 |
+
which the multitasking has to be done.
|
| 32 |
+
|
| 33 |
+
sampling_func (Optional[Callable], optional): Function which will return
|
| 34 |
+
the next index to be selected. Defaults to equally weight sampling.
|
| 35 |
+
"""
|
| 36 |
+
if loaders is None or len(loaders) == 0:
|
| 37 |
+
warnings.warn(
|
| 38 |
+
"Empty loaders passed into MultiDataLoader. This can have "
|
| 39 |
+
"unintended consequences."
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
if sampling_func is None:
|
| 43 |
+
sampling_func = partial(random.choice, range(len(loaders)))
|
| 44 |
+
|
| 45 |
+
self.sampling_func = sampling_func
|
| 46 |
+
self.loaders = loaders
|
| 47 |
+
self.num_datasets = len(self.loaders)
|
| 48 |
+
self.iterators = [None for _ in loaders]
|
| 49 |
+
self.current_index = 0
|
| 50 |
+
self.set_samplers()
|
| 51 |
+
|
| 52 |
+
def set_samplers(self):
|
| 53 |
+
self.samplers: List[torch.utils.data.Sampler] = []
|
| 54 |
+
for loader in self.loaders:
|
| 55 |
+
if hasattr(loader, "sampler"):
|
| 56 |
+
self.samplers.append(loader.sampler)
|
| 57 |
+
|
| 58 |
+
def __iter__(self):
|
| 59 |
+
self.iterators = []
|
| 60 |
+
|
| 61 |
+
for loader in self.loaders:
|
| 62 |
+
self.iterators.append(iter(loader))
|
| 63 |
+
|
| 64 |
+
self.change_dataloader()
|
| 65 |
+
|
| 66 |
+
return self
|
| 67 |
+
|
| 68 |
+
def __next__(self):
|
| 69 |
+
"""
|
| 70 |
+
Calculation of next batch is performed using following logic.
|
| 71 |
+
|
| 72 |
+
Current chosen iterator is set in the change_dataloader function
|
| 73 |
+
based on the `sampling_func` function passed to `__init__` of the
|
| 74 |
+
dataloader which is called to get the index of next selected dataloader.
|
| 75 |
+
|
| 76 |
+
If we get the next batch from iterator without any StopIteration exception,
|
| 77 |
+
we return it as it is.
|
| 78 |
+
|
| 79 |
+
Epochs don't make sense in case of using `sampling_func` unless you add
|
| 80 |
+
extra logic to support epoch-based sampling functions. MMF does this in
|
| 81 |
+
a different way, so take a look at IterationStrategies there to understand
|
| 82 |
+
how this can be possibly done.
|
| 83 |
+
|
| 84 |
+
Think of a case of random (equal) proportional sampling for dataset x and y
|
| 85 |
+
where x is half the size of y. When x will complete its 2 epochs, y will
|
| 86 |
+
have only 1 epoch completed. **So please don't use max_epochs or epoch
|
| 87 |
+
based training in this case as it won't be honored**. If an iterator is
|
| 88 |
+
finished, we just reignite it in this case and finished iterators
|
| 89 |
+
variable isn't used. This means that this case will never reach the
|
| 90 |
+
__iter__ function ever again.
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
Dict: Contains two keys, one "batch" containing the batch from current
|
| 95 |
+
selected dataloader and "datamodule_index" which is index of
|
| 96 |
+
currently selected dataloader.
|
| 97 |
+
"""
|
| 98 |
+
self.change_dataloader()
|
| 99 |
+
try:
|
| 100 |
+
next_batch = next(self.current_iterator)
|
| 101 |
+
except StopIteration:
|
| 102 |
+
iterator = iter(self.loaders[self.current_index])
|
| 103 |
+
self.iterators[self.current_index] = iterator
|
| 104 |
+
self.current_iterator = iterator
|
| 105 |
+
next_batch = next(self.current_iterator)
|
| 106 |
+
|
| 107 |
+
return {"batch": next_batch, "datamodule_index": self.current_index}
|
| 108 |
+
|
| 109 |
+
def change_dataloader(self):
|
| 110 |
+
choice = 0
|
| 111 |
+
|
| 112 |
+
if self.num_datasets <= 1:
|
| 113 |
+
self.current_index = choice
|
| 114 |
+
self.current_iterator = self.iterators[self.current_index]
|
| 115 |
+
return
|
| 116 |
+
|
| 117 |
+
choice = [self.sampling_func()]
|
| 118 |
+
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
| 119 |
+
# This broadcast is probably unnecessary with lightning if everything
|
| 120 |
+
# is already properly seeded. But,to be on safe side, we can still
|
| 121 |
+
# do this.
|
| 122 |
+
# There are also some smarter ways to do this to avoid any broadcasting
|
| 123 |
+
# by basically having a fixed generator with a fixed seed which will
|
| 124 |
+
# always work deterministically.
|
| 125 |
+
# TODO: Check if not doing this provides any speed benefits.
|
| 126 |
+
torch.distributed.broadcast_object_list(choice, 0)
|
| 127 |
+
|
| 128 |
+
self.current_index = choice[0]
|
| 129 |
+
self.current_iterator = self.iterators[self.current_index]
|
| 130 |
+
|
| 131 |
+
def set_epoch(self, epoch: int):
|
| 132 |
+
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
| 133 |
+
for sampler in self.samplers:
|
| 134 |
+
if sampler is not None and hasattr(sampler, "set_epoch"):
|
| 135 |
+
sampler.set_epoch(epoch)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class MultiDataModule(LightningDataModule):
|
| 139 |
+
"""MultiDataModule is just an abstraction over MultiDataLoader
|
| 140 |
+
that will allow us to integrate it with Lightning.
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
# NOTE: Add rest of the functions that should be called on child datamodules
|
| 144 |
+
# as required
|
| 145 |
+
def __init__(
|
| 146 |
+
self,
|
| 147 |
+
datamodules: List[LightningDataModule],
|
| 148 |
+
sampling_func: Optional[Callable] = None,
|
| 149 |
+
):
|
| 150 |
+
super().__init__()
|
| 151 |
+
self.datamodules = datamodules
|
| 152 |
+
self.sampling_func = sampling_func
|
| 153 |
+
self.current_datamodule_idx = 0
|
| 154 |
+
|
| 155 |
+
def setup(self, stage=None):
|
| 156 |
+
for datamodule in self.datamodules:
|
| 157 |
+
datamodule.setup(stage)
|
| 158 |
+
|
| 159 |
+
def prepare_data(self):
|
| 160 |
+
for datamodule in self.datamodules:
|
| 161 |
+
datamodule.prepare_data()
|
| 162 |
+
|
| 163 |
+
def train_dataloader(self) -> MultiDataLoader:
|
| 164 |
+
# TODO: Fix assign inconsistency
|
| 165 |
+
return self._build_multi_dataloader("train")
|
| 166 |
+
|
| 167 |
+
def val_dataloader(self) -> MultiDataLoader:
|
| 168 |
+
return self._build_multi_dataloader("val")
|
| 169 |
+
|
| 170 |
+
def test_dataloader(self) -> MultiDataLoader:
|
| 171 |
+
return self._build_multi_dataloader("test")
|
| 172 |
+
|
| 173 |
+
def _build_multi_dataloader(self, split="train"):
|
| 174 |
+
dataloaders = []
|
| 175 |
+
for datamodule in self.datamodules:
|
| 176 |
+
dataloaders.append(getattr(datamodule, f"{split}_dataloader")())
|
| 177 |
+
|
| 178 |
+
return MultiDataLoader(dataloaders, self.sampling_func)
|
| 179 |
+
|
| 180 |
+
def on_before_batch_transfer(self, batch, *args):
|
| 181 |
+
batch, index = batch["batch"], batch["datamodule_index"]
|
| 182 |
+
self.current_datamodule_idx = index
|
| 183 |
+
return self.datamodules[self.current_datamodule_idx].on_before_batch_transfer(
|
| 184 |
+
batch, *args
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def on_after_batch_transfer(self, batch, *args):
|
| 188 |
+
return self.datamodules[self.current_datamodule_idx].on_after_batch_transfer(
|
| 189 |
+
batch, *args
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def teardown(self, stage):
|
| 193 |
+
for datamodule in self.datamodules:
|
| 194 |
+
datamodule.teardown(stage)
|
multimodal/examples/flava/callbacks/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .multimodal_eval import * # noqa F401
|
multimodal/examples/flava/callbacks/multimodal_eval.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from flava.data import default_text_transform, VL_MAX_LENGTH_DEFAULT
|
| 11 |
+
from flava.data.imagenet_zeroshot_data import (
|
| 12 |
+
imagenet_classnames,
|
| 13 |
+
openai_imagenet_template,
|
| 14 |
+
)
|
| 15 |
+
from pytorch_lightning import Callback, LightningDataModule
|
| 16 |
+
from pytorch_lightning.utilities import rank_zero_only
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _zero_shot_classifier(model, device, text_transform, *args, **kwargs):
|
| 24 |
+
zeroshot_weights = []
|
| 25 |
+
for classname in tqdm(imagenet_classnames):
|
| 26 |
+
texts = text_transform(
|
| 27 |
+
[template(classname) for template in openai_imagenet_template]
|
| 28 |
+
)["input_ids"]
|
| 29 |
+
texts = texts.to(device)
|
| 30 |
+
class_embeddings = model.encode_text(texts)
|
| 31 |
+
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
|
| 32 |
+
class_embedding = class_embeddings.mean(dim=0)
|
| 33 |
+
class_embedding /= class_embedding.norm()
|
| 34 |
+
zeroshot_weights.append(class_embedding)
|
| 35 |
+
|
| 36 |
+
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
|
| 37 |
+
return zeroshot_weights
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _accuracy(output, target, topk=(1,)):
|
| 41 |
+
pred = output.topk(max(topk), 1, True, True)[1].t()
|
| 42 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
| 43 |
+
return [
|
| 44 |
+
float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
|
| 45 |
+
for k in topk
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@rank_zero_only
|
| 50 |
+
def run_imagenet_zero_shot(model, dataloader, device, text_transform, *args, **kwargs):
|
| 51 |
+
logger.info("Starting ImageNet Zero-Shot Eval")
|
| 52 |
+
logger.info("Building classifier")
|
| 53 |
+
classifier = _zero_shot_classifier(model, device, text_transform)
|
| 54 |
+
logger.info("Classifier built")
|
| 55 |
+
top1, top5, n = 0.0, 0.0, 0.0
|
| 56 |
+
for sample in tqdm(dataloader):
|
| 57 |
+
images = sample["image"]
|
| 58 |
+
target = sample["label"]
|
| 59 |
+
images = images.to(device)
|
| 60 |
+
target = target.to(device)
|
| 61 |
+
|
| 62 |
+
# predict
|
| 63 |
+
# if hasattr(model, "module"):
|
| 64 |
+
# image_features = model.module.encode_image({"image": images})
|
| 65 |
+
# else:
|
| 66 |
+
image_features = model.encode_image(images)
|
| 67 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
| 68 |
+
logits = 100.0 * image_features @ classifier
|
| 69 |
+
|
| 70 |
+
# measure accuracy
|
| 71 |
+
acc1, acc5 = _accuracy(logits, target, topk=(1, 5))
|
| 72 |
+
top1 += acc1
|
| 73 |
+
top5 += acc5
|
| 74 |
+
n += images.size(0)
|
| 75 |
+
|
| 76 |
+
top1 = top1 / n
|
| 77 |
+
top5 = top5 / n
|
| 78 |
+
results = {}
|
| 79 |
+
results["imagenet-zeroshot-val-top1"] = top1
|
| 80 |
+
results["imagenet-zeroshot-val-top5"] = top5
|
| 81 |
+
return results
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MultimodalEvalCallback(Callback):
|
| 85 |
+
def __init__(self, imagenet_datamodule: LightningDataModule, *args, **kwargs):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.imagenet_val_dataloader = imagenet_datamodule.val_dataloader()
|
| 88 |
+
self.text_transform = default_text_transform(
|
| 89 |
+
max_text_length=VL_MAX_LENGTH_DEFAULT
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
@torch.no_grad()
|
| 93 |
+
def on_validation_start(self, trainer, pl_module, **kwargs) -> None:
|
| 94 |
+
metrics = run_imagenet_zero_shot(
|
| 95 |
+
pl_module.model,
|
| 96 |
+
self.imagenet_val_dataloader,
|
| 97 |
+
pl_module.device,
|
| 98 |
+
self.text_transform,
|
| 99 |
+
)
|
| 100 |
+
if metrics is not None:
|
| 101 |
+
for key in metrics:
|
| 102 |
+
self.log(
|
| 103 |
+
f"val/{key}",
|
| 104 |
+
metrics[key],
|
| 105 |
+
prog_bar=True,
|
| 106 |
+
logger=True,
|
| 107 |
+
rank_zero_only=True,
|
| 108 |
+
)
|
multimodal/examples/flava/configs/finetuning/qnli.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Note that in original FLAVA paper, only Logistic Regression numbers were provided for image datasets.
|
| 2 |
+
_target_: flava.definitions.FLAVAArguments
|
| 3 |
+
training:
|
| 4 |
+
_target_: flava.definitions.TrainingArguments
|
| 5 |
+
lightning:
|
| 6 |
+
max_steps: 33112
|
| 7 |
+
gpus: 1
|
| 8 |
+
val_check_interval: 1000
|
| 9 |
+
num_sanity_val_steps: 0
|
| 10 |
+
strategy: ddp
|
| 11 |
+
lightning_checkpoint:
|
| 12 |
+
dirpath: "."
|
| 13 |
+
filename: flava-{epoch:02d}-{step}
|
| 14 |
+
save_last: true
|
| 15 |
+
every_n_train_steps: 1000
|
| 16 |
+
save_on_train_epoch_end: true
|
| 17 |
+
verbose: true
|
| 18 |
+
monitor: validation/accuracy/classification
|
| 19 |
+
mode: max
|
| 20 |
+
lightning_load_from_checkpoint: null
|
| 21 |
+
seed: -1
|
| 22 |
+
batch_size: 32
|
| 23 |
+
num_workers: 4
|
| 24 |
+
learning_rate: 1e-5
|
| 25 |
+
adam_eps: 1e-6
|
| 26 |
+
adam_weight_decay: 0.1
|
| 27 |
+
adam_betas:
|
| 28 |
+
- 0.9
|
| 29 |
+
- 0.98
|
| 30 |
+
warmup_steps: 1986
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
datasets:
|
| 34 |
+
_target_: flava.definitions.TrainingDatasetsInfo
|
| 35 |
+
selected:
|
| 36 |
+
- text
|
| 37 |
+
num_classes: 2
|
| 38 |
+
text:
|
| 39 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 40 |
+
train:
|
| 41 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 42 |
+
key: glue
|
| 43 |
+
subset: qnli
|
| 44 |
+
rename_columns:
|
| 45 |
+
- ["question", "sentence1"]
|
| 46 |
+
- ["sentence", "sentence2"]
|
| 47 |
+
datamodule_extra_kwargs:
|
| 48 |
+
text_columns: ["sentence1", "sentence2"]
|
multimodal/examples/flava/configs/finetuning/rendered_sst2.yaml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Note that in original FLAVA paper, only Logistic Regression numbers were provided for image datasets.
|
| 2 |
+
_target_: flava.definitions.FLAVAArguments
|
| 3 |
+
training:
|
| 4 |
+
_target_: flava.definitions.TrainingArguments
|
| 5 |
+
lightning:
|
| 6 |
+
max_steps: 20935
|
| 7 |
+
gpus: -1
|
| 8 |
+
val_check_interval: 100
|
| 9 |
+
num_sanity_val_steps: 0
|
| 10 |
+
strategy: ddp
|
| 11 |
+
lightning_checkpoint:
|
| 12 |
+
dirpath: "."
|
| 13 |
+
filename: flava-{epoch:02d}-{step}
|
| 14 |
+
save_last: true
|
| 15 |
+
every_n_train_steps: 1000
|
| 16 |
+
save_on_train_epoch_end: true
|
| 17 |
+
verbose: true
|
| 18 |
+
lightning_load_from_checkpoint: null
|
| 19 |
+
seed: -1
|
| 20 |
+
batch_size: 32
|
| 21 |
+
num_workers: 4
|
| 22 |
+
learning_rate: 1e-5
|
| 23 |
+
adam_eps: 1e-8
|
| 24 |
+
adam_weight_decay: 1e-2
|
| 25 |
+
warmup_steps: 1256
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
datasets:
|
| 29 |
+
_target_: flava.definitions.TrainingDatasetsInfo
|
| 30 |
+
selected:
|
| 31 |
+
- image
|
| 32 |
+
num_classes: 2
|
| 33 |
+
image:
|
| 34 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 35 |
+
train:
|
| 36 |
+
- _target_: flava.definitions.TorchVisionDatasetInfo
|
| 37 |
+
key: RenderedSST2
|
multimodal/examples/flava/configs/pretraining/debug.yaml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: flava.definitions.FLAVAArguments
|
| 2 |
+
training:
|
| 3 |
+
_target_: flava.definitions.TrainingArguments
|
| 4 |
+
lightning:
|
| 5 |
+
max_steps: 450000
|
| 6 |
+
gpus: -1
|
| 7 |
+
val_check_interval: 10000
|
| 8 |
+
num_sanity_val_steps: 0
|
| 9 |
+
strategy: ddp
|
| 10 |
+
lightning_checkpoint:
|
| 11 |
+
dirpath: "."
|
| 12 |
+
filename: flava-{epoch:02d}-{step}
|
| 13 |
+
save_last: true
|
| 14 |
+
every_n_train_steps: 1000
|
| 15 |
+
save_on_train_epoch_end: true
|
| 16 |
+
verbose: true
|
| 17 |
+
lightning_load_from_checkpoint: null
|
| 18 |
+
seed: -1
|
| 19 |
+
batch_size: 8
|
| 20 |
+
num_workers: 4
|
| 21 |
+
learning_rate: 2e-4
|
| 22 |
+
adam_eps: 1e-8
|
| 23 |
+
adam_weight_decay: 1e-2
|
| 24 |
+
warmup_steps: 2000
|
| 25 |
+
|
| 26 |
+
datasets:
|
| 27 |
+
_target_: flava.definitions.TrainingDatasetsInfo
|
| 28 |
+
selected:
|
| 29 |
+
- image
|
| 30 |
+
- vl
|
| 31 |
+
- text
|
| 32 |
+
image:
|
| 33 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 34 |
+
train:
|
| 35 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 36 |
+
key: imagenet-1k
|
| 37 |
+
subset: default
|
| 38 |
+
text:
|
| 39 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 40 |
+
train:
|
| 41 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 42 |
+
key: wikitext
|
| 43 |
+
subset: wikitext-103-raw-v1
|
| 44 |
+
datamodule_extra_kwargs:
|
| 45 |
+
text_columns: ["text"]
|
| 46 |
+
vl:
|
| 47 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 48 |
+
train:
|
| 49 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 50 |
+
key: red_caps
|
| 51 |
+
subset: jellyfish
|
| 52 |
+
rename_columns:
|
| 53 |
+
- ["caption", "text"]
|
| 54 |
+
val:
|
| 55 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 56 |
+
key: red_caps
|
| 57 |
+
subset: jellyfish
|
| 58 |
+
rename_columns:
|
| 59 |
+
- ["caption", "text"]
|
| 60 |
+
split_key_mapping:
|
| 61 |
+
validation: train
|
multimodal/examples/flava/data/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .transforms import * # noqa F401
|
| 8 |
+
from .utils import * # noqa F401
|
| 9 |
+
from .imagenet_zeroshot_data import * # noqa F401
|
| 10 |
+
from .datamodules import * # noqa F401
|
multimodal/examples/flava/data/datamodules.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torchvision
|
| 13 |
+
from flava.definitions import HFDatasetInfo, TorchVisionDatasetInfo
|
| 14 |
+
from pytorch_lightning import LightningDataModule
|
| 15 |
+
from transformers import (
|
| 16 |
+
BertTokenizer,
|
| 17 |
+
DataCollatorForLanguageModeling,
|
| 18 |
+
DataCollatorForWholeWordMask,
|
| 19 |
+
DefaultDataCollator,
|
| 20 |
+
TRANSFORMERS_CACHE,
|
| 21 |
+
)
|
| 22 |
+
from transformers.data.data_collator import torch_default_data_collator
|
| 23 |
+
|
| 24 |
+
from .transforms import (
|
| 25 |
+
default_image_pretraining_transforms,
|
| 26 |
+
default_text_transform,
|
| 27 |
+
default_torchvision_transforms,
|
| 28 |
+
encode_text_batch,
|
| 29 |
+
pad_batch,
|
| 30 |
+
TEXT_DEFAULT_TOKENIZER,
|
| 31 |
+
TEXT_WHOLE_WORD_MASK_TOKENIZER,
|
| 32 |
+
VL_MAX_LENGTH_DEFAULT,
|
| 33 |
+
VLTransform,
|
| 34 |
+
)
|
| 35 |
+
from .utils import build_datasets_from_info, fetch_images
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def transform_image(transform, sample):
|
| 39 |
+
sample.update(transform(sample["image"]))
|
| 40 |
+
return sample
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class DataCollatorForWholeWordMaskRetainingBatch(DataCollatorForWholeWordMask):
|
| 44 |
+
def torch_call(
|
| 45 |
+
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
| 46 |
+
) -> Dict[str, Any]:
|
| 47 |
+
masked_batch = super().torch_call(examples)
|
| 48 |
+
examples = torch_default_data_collator(examples)
|
| 49 |
+
examples["input_ids"] = masked_batch["input_ids"]
|
| 50 |
+
examples["labels"] = masked_batch["labels"]
|
| 51 |
+
return examples
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ImageDataModule(LightningDataModule):
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
train_infos: List[HFDatasetInfo],
|
| 58 |
+
val_infos: Optional[List[HFDatasetInfo]] = None,
|
| 59 |
+
transforms: Optional[Tuple[Callable, Callable]] = None,
|
| 60 |
+
batch_size: int = 32,
|
| 61 |
+
num_workers: int = 4,
|
| 62 |
+
allow_uneven_batches: bool = False,
|
| 63 |
+
**kwargs: Any,
|
| 64 |
+
):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.train_dataset_infos = train_infos
|
| 67 |
+
self.val_dataset_infos = val_infos
|
| 68 |
+
if self.val_dataset_infos is None:
|
| 69 |
+
self.val_dataset_infos = train_infos
|
| 70 |
+
|
| 71 |
+
self.batch_size = batch_size
|
| 72 |
+
self.num_workers = num_workers
|
| 73 |
+
self.allow_uneven_batches = allow_uneven_batches
|
| 74 |
+
|
| 75 |
+
if transforms is None:
|
| 76 |
+
transforms = default_image_pretraining_transforms()
|
| 77 |
+
|
| 78 |
+
self.train_transform, self.test_transform = transforms
|
| 79 |
+
|
| 80 |
+
def setup(self, stage=None):
|
| 81 |
+
train_transform = partial(transform_image, self.train_transform)
|
| 82 |
+
val_transform = partial(transform_image, self.test_transform)
|
| 83 |
+
|
| 84 |
+
self.train_dataset = build_datasets_from_info(
|
| 85 |
+
self.train_dataset_infos, split="train"
|
| 86 |
+
)
|
| 87 |
+
self.train_dataset.set_transform(train_transform)
|
| 88 |
+
self.val_dataset = build_datasets_from_info(
|
| 89 |
+
self.val_dataset_infos, split="validation"
|
| 90 |
+
)
|
| 91 |
+
self.val_dataset.set_transform(val_transform)
|
| 92 |
+
|
| 93 |
+
def train_dataloader(self):
|
| 94 |
+
return torch.utils.data.DataLoader(
|
| 95 |
+
self.train_dataset,
|
| 96 |
+
batch_size=self.batch_size,
|
| 97 |
+
num_workers=self.num_workers,
|
| 98 |
+
sampler=None,
|
| 99 |
+
shuffle=True,
|
| 100 |
+
# uneven batches can cause distributed issues,
|
| 101 |
+
# drop last batch to prevent those.
|
| 102 |
+
# ideally, we don't need to drop these for unimodal cases
|
| 103 |
+
# but just to be safe
|
| 104 |
+
drop_last=True,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def val_dataloader(self):
|
| 108 |
+
return torch.utils.data.DataLoader(
|
| 109 |
+
self.val_dataset,
|
| 110 |
+
batch_size=self.batch_size,
|
| 111 |
+
num_workers=self.num_workers,
|
| 112 |
+
sampler=None,
|
| 113 |
+
shuffle=False,
|
| 114 |
+
# uneven batches can cause distributed issues,
|
| 115 |
+
# drop last batch to prevent those.
|
| 116 |
+
# ideally, we don't need to drop these for unimodal cases
|
| 117 |
+
# but just to be safe
|
| 118 |
+
drop_last=True,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def test_dataloader(self):
|
| 122 |
+
return self.val_dataloader()
|
| 123 |
+
|
| 124 |
+
def on_before_batch_transfer(self, batch, *args):
|
| 125 |
+
if batch["label"].size(0) < self.batch_size and not self.allow_uneven_batches:
|
| 126 |
+
batch = pad_batch(batch, self.batch_size)
|
| 127 |
+
return batch
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class TextDataModule(LightningDataModule):
|
| 131 |
+
def __init__(
|
| 132 |
+
self,
|
| 133 |
+
train_infos: List[HFDatasetInfo],
|
| 134 |
+
text_columns: List[str],
|
| 135 |
+
val_infos: Optional[List[HFDatasetInfo]] = None,
|
| 136 |
+
tokenizer: Optional[Callable] = None,
|
| 137 |
+
max_length: int = 512,
|
| 138 |
+
batch_size: int = 32,
|
| 139 |
+
num_workers: int = 4,
|
| 140 |
+
allow_uneven_batches: bool = False,
|
| 141 |
+
**kwargs: Any,
|
| 142 |
+
):
|
| 143 |
+
super().__init__()
|
| 144 |
+
self.train_dataset_infos = train_infos
|
| 145 |
+
self.text_columns = text_columns
|
| 146 |
+
self.val_dataset_infos = val_infos
|
| 147 |
+
if self.val_dataset_infos is None:
|
| 148 |
+
self.val_dataset_infos = train_infos
|
| 149 |
+
self.tokenizer = tokenizer
|
| 150 |
+
self.max_length = max_length
|
| 151 |
+
self.batch_size = batch_size
|
| 152 |
+
self.num_workers = num_workers
|
| 153 |
+
self.allow_uneven_batches = allow_uneven_batches
|
| 154 |
+
|
| 155 |
+
def setup(self, stage=None):
|
| 156 |
+
if self.tokenizer is None:
|
| 157 |
+
self.tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
|
| 158 |
+
transform = partial(
|
| 159 |
+
encode_text_batch,
|
| 160 |
+
tokenizer=self.tokenizer,
|
| 161 |
+
padding="max_length",
|
| 162 |
+
max_length=self.max_length,
|
| 163 |
+
truncation=True,
|
| 164 |
+
return_tensors="pt",
|
| 165 |
+
return_special_tokens_mask=True,
|
| 166 |
+
text_columns=self.text_columns,
|
| 167 |
+
return_batch=True,
|
| 168 |
+
)
|
| 169 |
+
self.train_dataset = build_datasets_from_info(
|
| 170 |
+
self.train_dataset_infos, split="train"
|
| 171 |
+
)
|
| 172 |
+
self.train_dataset.set_transform(transform)
|
| 173 |
+
self.val_dataset = build_datasets_from_info(
|
| 174 |
+
self.val_dataset_infos, split="validation"
|
| 175 |
+
)
|
| 176 |
+
self.val_dataset.set_transform(transform)
|
| 177 |
+
|
| 178 |
+
def train_dataloader(self):
|
| 179 |
+
return self._build_dataloader(self.train_dataset)
|
| 180 |
+
|
| 181 |
+
def val_dataloader(self):
|
| 182 |
+
return self._build_dataloader(self.val_dataset, shuffle=False)
|
| 183 |
+
|
| 184 |
+
def _build_dataloader(self, dataset, drop_last=False, shuffle=True):
|
| 185 |
+
return torch.utils.data.DataLoader(
|
| 186 |
+
dataset,
|
| 187 |
+
batch_size=self.batch_size,
|
| 188 |
+
num_workers=self.num_workers,
|
| 189 |
+
sampler=None,
|
| 190 |
+
shuffle=shuffle,
|
| 191 |
+
collate_fn=self._build_collator(),
|
| 192 |
+
drop_last=drop_last,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
def _build_collator(self):
|
| 196 |
+
return DefaultDataCollator()
|
| 197 |
+
|
| 198 |
+
def on_before_batch_transfer(self, batch, *args):
|
| 199 |
+
batch.pop("token_type_ids", None)
|
| 200 |
+
mask = batch.pop("attention_mask", None)
|
| 201 |
+
if mask.size(0) < self.batch_size and not self.allow_uneven_batches:
|
| 202 |
+
batch = pad_batch(batch, self.batch_size)
|
| 203 |
+
return batch
|
| 204 |
+
|
| 205 |
+
def on_after_batch_transfer(self, batch, *args):
|
| 206 |
+
batch["text"] = batch.pop("input_ids")
|
| 207 |
+
return batch
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class MLMDataModule(TextDataModule):
|
| 211 |
+
def __init__(
|
| 212 |
+
self,
|
| 213 |
+
train_infos: List[HFDatasetInfo],
|
| 214 |
+
text_columns: List[str],
|
| 215 |
+
val_infos: Optional[List[HFDatasetInfo]] = None,
|
| 216 |
+
mlm_probability: float = 0.15,
|
| 217 |
+
ignore_index: int = -1,
|
| 218 |
+
**kwargs: Any,
|
| 219 |
+
):
|
| 220 |
+
super().__init__(train_infos, text_columns, val_infos, **kwargs)
|
| 221 |
+
self.mlm_probability = mlm_probability
|
| 222 |
+
self.ignore_index = ignore_index
|
| 223 |
+
|
| 224 |
+
def setup(self, stage=None):
|
| 225 |
+
if self.tokenizer is None:
|
| 226 |
+
self.tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
|
| 227 |
+
transform = partial(
|
| 228 |
+
encode_text_batch,
|
| 229 |
+
tokenizer=self.tokenizer,
|
| 230 |
+
padding="max_length",
|
| 231 |
+
max_length=self.max_length,
|
| 232 |
+
truncation=True,
|
| 233 |
+
return_tensors="pt",
|
| 234 |
+
return_special_tokens_mask=True,
|
| 235 |
+
text_columns=self.text_columns,
|
| 236 |
+
return_batch=False,
|
| 237 |
+
)
|
| 238 |
+
self.train_dataset = build_datasets_from_info(
|
| 239 |
+
self.train_dataset_infos, split="train"
|
| 240 |
+
)
|
| 241 |
+
self.train_dataset.set_transform(transform)
|
| 242 |
+
self.val_dataset = build_datasets_from_info(
|
| 243 |
+
self.val_dataset_infos, split="validation"
|
| 244 |
+
)
|
| 245 |
+
self.val_dataset.set_transform(transform)
|
| 246 |
+
|
| 247 |
+
def _build_dataloader(self, dataset, drop_last=True, shuffle=True):
|
| 248 |
+
# uneven batches can cause distributed issues,
|
| 249 |
+
# drop last batch to prevent those.
|
| 250 |
+
# ideally, we don't need to drop these for unimodal cases
|
| 251 |
+
# but just to be safe
|
| 252 |
+
return super()._build_dataloader(dataset, drop_last=drop_last, shuffle=shuffle)
|
| 253 |
+
|
| 254 |
+
def _build_collator(self):
|
| 255 |
+
return DataCollatorForLanguageModeling(
|
| 256 |
+
self.tokenizer, mlm_probability=self.mlm_probability
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
def on_after_batch_transfer(self, batch, *args):
|
| 260 |
+
batch["text_masked"] = batch.pop("input_ids")
|
| 261 |
+
batch["mlm_labels"] = batch.pop("labels")
|
| 262 |
+
batch["mlm_labels"][batch["mlm_labels"] == -100] = self.ignore_index
|
| 263 |
+
return batch
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class VLDataModule(LightningDataModule):
|
| 267 |
+
def __init__(
|
| 268 |
+
self,
|
| 269 |
+
train_infos: List[HFDatasetInfo],
|
| 270 |
+
val_infos: List[HFDatasetInfo],
|
| 271 |
+
text_transform: Optional[Callable] = None,
|
| 272 |
+
image_transforms: Optional[Tuple[Callable, Callable]] = None,
|
| 273 |
+
mlm_probablity: float = 0.15,
|
| 274 |
+
batch_size: int = 32,
|
| 275 |
+
num_workers: int = 4,
|
| 276 |
+
finetuning: bool = False,
|
| 277 |
+
ignore_index: int = -1,
|
| 278 |
+
itm_probability: float = 0.1,
|
| 279 |
+
allow_uneven_batches: bool = False,
|
| 280 |
+
fetch_num_threads: int = 4,
|
| 281 |
+
fetch_retries: int = 0,
|
| 282 |
+
fetch_sleep_timer: int = 0,
|
| 283 |
+
fetch_timeout: Optional[float] = None,
|
| 284 |
+
fetch_batch_size: int = 50,
|
| 285 |
+
**kwargs,
|
| 286 |
+
):
|
| 287 |
+
super().__init__()
|
| 288 |
+
|
| 289 |
+
self.train_dataset_infos = train_infos
|
| 290 |
+
self.val_dataset_infos = val_infos
|
| 291 |
+
if self.val_dataset_infos is None:
|
| 292 |
+
self.val_dataset_infos = train_infos
|
| 293 |
+
if image_transforms is None:
|
| 294 |
+
if not finetuning:
|
| 295 |
+
image_transforms = default_image_pretraining_transforms()
|
| 296 |
+
else:
|
| 297 |
+
image_transforms = default_torchvision_transforms(use_dict=True)
|
| 298 |
+
|
| 299 |
+
self.train_image_transform, self.test_image_transform = image_transforms
|
| 300 |
+
self.text_transform = text_transform
|
| 301 |
+
self.mlm_probability = mlm_probablity
|
| 302 |
+
self.batch_size = batch_size
|
| 303 |
+
self.num_workers = num_workers
|
| 304 |
+
self.ignore_index = ignore_index
|
| 305 |
+
self.itm_probability = itm_probability
|
| 306 |
+
self.allow_uneven_batches = allow_uneven_batches
|
| 307 |
+
self.fetch_num_threads = fetch_num_threads
|
| 308 |
+
self.fetch_retries = fetch_retries
|
| 309 |
+
self.fetch_sleep_timer = fetch_sleep_timer
|
| 310 |
+
self.fetch_timeout = fetch_timeout
|
| 311 |
+
self.fetch_batch_size = fetch_batch_size
|
| 312 |
+
|
| 313 |
+
def setup(self, stage=None):
|
| 314 |
+
if self.text_transform is None:
|
| 315 |
+
# TODO Update to use whole word mask vocab
|
| 316 |
+
text_tokenizer = BertTokenizer.from_pretrained(
|
| 317 |
+
TEXT_WHOLE_WORD_MASK_TOKENIZER
|
| 318 |
+
)
|
| 319 |
+
self.text_transform = default_text_transform(
|
| 320 |
+
text_tokenizer, max_text_length=VL_MAX_LENGTH_DEFAULT
|
| 321 |
+
)
|
| 322 |
+
self.text_tokenizer = self.text_transform.keywords["tokenizer"]
|
| 323 |
+
train_vl_transform = VLTransform(
|
| 324 |
+
self.train_image_transform, self.text_transform
|
| 325 |
+
)
|
| 326 |
+
val_vl_transform = VLTransform(self.test_image_transform, self.text_transform)
|
| 327 |
+
|
| 328 |
+
train_dataset = build_datasets_from_info(
|
| 329 |
+
self.train_dataset_infos, split="train"
|
| 330 |
+
)
|
| 331 |
+
train_dataset = train_dataset.map(
|
| 332 |
+
fetch_images,
|
| 333 |
+
batched=True,
|
| 334 |
+
batch_size=self.fetch_batch_size,
|
| 335 |
+
fn_kwargs={
|
| 336 |
+
"num_threads": self.fetch_num_threads,
|
| 337 |
+
"timeout": self.fetch_timeout,
|
| 338 |
+
"retries": self.fetch_retries,
|
| 339 |
+
"sleep_timer": self.fetch_sleep_timer,
|
| 340 |
+
},
|
| 341 |
+
)
|
| 342 |
+
train_dataset = train_dataset.filter(
|
| 343 |
+
lambda example: example["image"] is not None
|
| 344 |
+
)
|
| 345 |
+
self.train_dataset = train_dataset
|
| 346 |
+
self.train_dataset.set_transform(
|
| 347 |
+
partial(
|
| 348 |
+
train_vl_transform,
|
| 349 |
+
dataset=train_dataset.filter(lambda example: True),
|
| 350 |
+
itm_probability=self.itm_probability,
|
| 351 |
+
)
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
val_dataset = build_datasets_from_info(
|
| 355 |
+
self.val_dataset_infos, split="validation"
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
val_dataset = val_dataset.map(
|
| 359 |
+
fetch_images,
|
| 360 |
+
batched=True,
|
| 361 |
+
batch_size=self.fetch_batch_size,
|
| 362 |
+
fn_kwargs={
|
| 363 |
+
"num_threads": self.fetch_num_threads,
|
| 364 |
+
"timeout": self.fetch_timeout,
|
| 365 |
+
"retries": self.fetch_retries,
|
| 366 |
+
"sleep_timer": self.fetch_sleep_timer,
|
| 367 |
+
},
|
| 368 |
+
)
|
| 369 |
+
val_dataset = val_dataset.filter(lambda example: example["image"] is not None)
|
| 370 |
+
self.val_dataset = val_dataset
|
| 371 |
+
self.val_dataset.set_transform(
|
| 372 |
+
partial(
|
| 373 |
+
val_vl_transform,
|
| 374 |
+
dataset=self.val_dataset.filter(
|
| 375 |
+
lambda example: True
|
| 376 |
+
), # Pass a copy to transform
|
| 377 |
+
itm_probability=self.itm_probability,
|
| 378 |
+
)
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
def train_dataloader(self):
|
| 382 |
+
return torch.utils.data.DataLoader(
|
| 383 |
+
self.train_dataset,
|
| 384 |
+
batch_size=self.batch_size,
|
| 385 |
+
num_workers=self.num_workers,
|
| 386 |
+
sampler=None,
|
| 387 |
+
shuffle=True,
|
| 388 |
+
collate_fn=self._build_collator(),
|
| 389 |
+
# uneven batches can cause distributed issues,
|
| 390 |
+
# drop last batch to prevent those.
|
| 391 |
+
drop_last=True,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
def val_dataloader(self):
|
| 395 |
+
return torch.utils.data.DataLoader(
|
| 396 |
+
self.val_dataset,
|
| 397 |
+
batch_size=self.batch_size,
|
| 398 |
+
num_workers=self.num_workers,
|
| 399 |
+
sampler=None,
|
| 400 |
+
shuffle=False,
|
| 401 |
+
collate_fn=self._build_collator(),
|
| 402 |
+
# uneven batches can cause distributed issues,
|
| 403 |
+
# drop last batch to prevent those.
|
| 404 |
+
drop_last=True,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
def _build_collator(self):
|
| 408 |
+
return DataCollatorForWholeWordMaskRetainingBatch(
|
| 409 |
+
self.text_tokenizer, mlm_probability=self.mlm_probability
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
def on_before_batch_transfer(self, batch, *args):
|
| 413 |
+
batch.pop("token_type_ids", None)
|
| 414 |
+
mask = batch.pop("attention_mask", None)
|
| 415 |
+
if (
|
| 416 |
+
mask is not None
|
| 417 |
+
and mask.size(0) < self.batch_size
|
| 418 |
+
and not self.allow_uneven_batches
|
| 419 |
+
):
|
| 420 |
+
batch = pad_batch(batch, self.batch_size)
|
| 421 |
+
return batch
|
| 422 |
+
|
| 423 |
+
def on_after_batch_transfer(self, batch, *args):
|
| 424 |
+
text_masked = batch.pop("input_ids")
|
| 425 |
+
mlm_labels = batch.pop("labels", None)
|
| 426 |
+
mlm_labels[mlm_labels == -100] = self.ignore_index
|
| 427 |
+
text = text_masked.detach().clone()
|
| 428 |
+
text[mlm_labels != -1] = mlm_labels[mlm_labels != -1]
|
| 429 |
+
batch.update(
|
| 430 |
+
{"mlm_labels": mlm_labels, "text": text, "text_masked": text_masked}
|
| 431 |
+
)
|
| 432 |
+
return batch
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class TorchVisionDataModule(LightningDataModule):
|
| 436 |
+
def __init__(
|
| 437 |
+
self,
|
| 438 |
+
train_infos: List[TorchVisionDatasetInfo],
|
| 439 |
+
# Val info is not used for torchvision datamodule, but kept to keep things consistent
|
| 440 |
+
val_infos: Optional[List[TorchVisionDatasetInfo]] = None,
|
| 441 |
+
dataset_root: Optional[str] = None,
|
| 442 |
+
image_transforms: Optional[Tuple[Callable, Callable]] = None,
|
| 443 |
+
batch_size: int = 32,
|
| 444 |
+
num_workers: int = 4,
|
| 445 |
+
**kwargs: Any,
|
| 446 |
+
):
|
| 447 |
+
super().__init__()
|
| 448 |
+
self.train_info = train_infos[0]
|
| 449 |
+
if val_infos is None:
|
| 450 |
+
val_infos = train_infos
|
| 451 |
+
self.val_info = val_infos[0]
|
| 452 |
+
|
| 453 |
+
self.train_class_ptr, self.train_root = self._parse_info(
|
| 454 |
+
self.train_info, dataset_root=dataset_root
|
| 455 |
+
)
|
| 456 |
+
self.val_class_ptr, self.val_root = self._parse_info(
|
| 457 |
+
self.val_info, dataset_root=dataset_root
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
if image_transforms is None:
|
| 461 |
+
image_transforms = default_torchvision_transforms()
|
| 462 |
+
|
| 463 |
+
self.train_transform, self.test_transform = image_transforms
|
| 464 |
+
self.batch_size = batch_size
|
| 465 |
+
self.num_workers = num_workers
|
| 466 |
+
|
| 467 |
+
def _parse_info(
|
| 468 |
+
self, info: TorchVisionDatasetInfo, dataset_root: Optional[str] = None
|
| 469 |
+
):
|
| 470 |
+
assert hasattr(
|
| 471 |
+
torchvision.datasets, info.key
|
| 472 |
+
), f"No dataset named {info.key} present in torchvision.datasets"
|
| 473 |
+
class_ptr = getattr(torchvision.datasets, info.key)
|
| 474 |
+
if dataset_root is None:
|
| 475 |
+
dataset_root = os.path.join(TRANSFORMERS_CACHE, "datasets", "torchvision")
|
| 476 |
+
dataset_root = os.path.join(dataset_root, class_ptr.__name__.lower())
|
| 477 |
+
os.makedirs(dataset_root, exist_ok=True)
|
| 478 |
+
|
| 479 |
+
return class_ptr, dataset_root
|
| 480 |
+
|
| 481 |
+
def setup(self, stage=None):
|
| 482 |
+
self.train_dataset = self.train_class_ptr(
|
| 483 |
+
self.train_root,
|
| 484 |
+
split=self.train_info.train_split,
|
| 485 |
+
transform=self.train_transform,
|
| 486 |
+
download=True,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
if self.val_info.has_val:
|
| 490 |
+
self.val_dataset = self.val_class_ptr(
|
| 491 |
+
self.val_root,
|
| 492 |
+
split=self.val_info.val_split,
|
| 493 |
+
transform=self.test_transform,
|
| 494 |
+
download=True,
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
self.test_dataset = self.val_class_ptr(
|
| 498 |
+
self.val_root,
|
| 499 |
+
split=self.val_info.test_split,
|
| 500 |
+
transform=self.test_transform,
|
| 501 |
+
download=True,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
def train_dataloader(self):
|
| 505 |
+
return self._build_dataloader(self.train_dataset)
|
| 506 |
+
|
| 507 |
+
def val_dataloader(self):
|
| 508 |
+
if self.val_info.has_val:
|
| 509 |
+
dataset = self.val_dataset
|
| 510 |
+
else:
|
| 511 |
+
dataset = self.test_dataset
|
| 512 |
+
|
| 513 |
+
return self._build_dataloader(dataset, shuffle=False)
|
| 514 |
+
|
| 515 |
+
def test_dataloader(self):
|
| 516 |
+
return self._build_dataloader(self.test_dataset, shuffle=False)
|
| 517 |
+
|
| 518 |
+
def _build_dataloader(self, dataset: torch.utils.data.Dataset, shuffle=True):
|
| 519 |
+
return torch.utils.data.DataLoader(
|
| 520 |
+
dataset,
|
| 521 |
+
shuffle=shuffle,
|
| 522 |
+
batch_size=self.batch_size,
|
| 523 |
+
num_workers=self.num_workers,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
def on_before_batch_transfer(self, batch, *args):
|
| 527 |
+
images, targets = batch
|
| 528 |
+
batch = {"image": images, "labels": targets}
|
| 529 |
+
return batch
|
multimodal/examples/flava/data/imagenet_zeroshot_data.py
ADDED
|
@@ -0,0 +1,1095 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# File taken from https://github.com/mlfoundations/open_clip/
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
imagenet_classnames = [
|
| 11 |
+
"tench",
|
| 12 |
+
"goldfish",
|
| 13 |
+
"great white shark",
|
| 14 |
+
"tiger shark",
|
| 15 |
+
"hammerhead shark",
|
| 16 |
+
"electric ray",
|
| 17 |
+
"stingray",
|
| 18 |
+
"rooster",
|
| 19 |
+
"hen",
|
| 20 |
+
"ostrich",
|
| 21 |
+
"brambling",
|
| 22 |
+
"goldfinch",
|
| 23 |
+
"house finch",
|
| 24 |
+
"junco",
|
| 25 |
+
"indigo bunting",
|
| 26 |
+
"American robin",
|
| 27 |
+
"bulbul",
|
| 28 |
+
"jay",
|
| 29 |
+
"magpie",
|
| 30 |
+
"chickadee",
|
| 31 |
+
"American dipper",
|
| 32 |
+
"kite (bird of prey)",
|
| 33 |
+
"bald eagle",
|
| 34 |
+
"vulture",
|
| 35 |
+
"great grey owl",
|
| 36 |
+
"fire salamander",
|
| 37 |
+
"smooth newt",
|
| 38 |
+
"newt",
|
| 39 |
+
"spotted salamander",
|
| 40 |
+
"axolotl",
|
| 41 |
+
"American bullfrog",
|
| 42 |
+
"tree frog",
|
| 43 |
+
"tailed frog",
|
| 44 |
+
"loggerhead sea turtle",
|
| 45 |
+
"leatherback sea turtle",
|
| 46 |
+
"mud turtle",
|
| 47 |
+
"terrapin",
|
| 48 |
+
"box turtle",
|
| 49 |
+
"banded gecko",
|
| 50 |
+
"green iguana",
|
| 51 |
+
"Carolina anole",
|
| 52 |
+
"desert grassland whiptail lizard",
|
| 53 |
+
"agama",
|
| 54 |
+
"frilled-necked lizard",
|
| 55 |
+
"alligator lizard",
|
| 56 |
+
"Gila monster",
|
| 57 |
+
"European green lizard",
|
| 58 |
+
"chameleon",
|
| 59 |
+
"Komodo dragon",
|
| 60 |
+
"Nile crocodile",
|
| 61 |
+
"American alligator",
|
| 62 |
+
"triceratops",
|
| 63 |
+
"worm snake",
|
| 64 |
+
"ring-necked snake",
|
| 65 |
+
"eastern hog-nosed snake",
|
| 66 |
+
"smooth green snake",
|
| 67 |
+
"kingsnake",
|
| 68 |
+
"garter snake",
|
| 69 |
+
"water snake",
|
| 70 |
+
"vine snake",
|
| 71 |
+
"night snake",
|
| 72 |
+
"boa constrictor",
|
| 73 |
+
"African rock python",
|
| 74 |
+
"Indian cobra",
|
| 75 |
+
"green mamba",
|
| 76 |
+
"sea snake",
|
| 77 |
+
"Saharan horned viper",
|
| 78 |
+
"eastern diamondback rattlesnake",
|
| 79 |
+
"sidewinder rattlesnake",
|
| 80 |
+
"trilobite",
|
| 81 |
+
"harvestman",
|
| 82 |
+
"scorpion",
|
| 83 |
+
"yellow garden spider",
|
| 84 |
+
"barn spider",
|
| 85 |
+
"European garden spider",
|
| 86 |
+
"southern black widow",
|
| 87 |
+
"tarantula",
|
| 88 |
+
"wolf spider",
|
| 89 |
+
"tick",
|
| 90 |
+
"centipede",
|
| 91 |
+
"black grouse",
|
| 92 |
+
"ptarmigan",
|
| 93 |
+
"ruffed grouse",
|
| 94 |
+
"prairie grouse",
|
| 95 |
+
"peafowl",
|
| 96 |
+
"quail",
|
| 97 |
+
"partridge",
|
| 98 |
+
"african grey parrot",
|
| 99 |
+
"macaw",
|
| 100 |
+
"sulphur-crested cockatoo",
|
| 101 |
+
"lorikeet",
|
| 102 |
+
"coucal",
|
| 103 |
+
"bee eater",
|
| 104 |
+
"hornbill",
|
| 105 |
+
"hummingbird",
|
| 106 |
+
"jacamar",
|
| 107 |
+
"toucan",
|
| 108 |
+
"duck",
|
| 109 |
+
"red-breasted merganser",
|
| 110 |
+
"goose",
|
| 111 |
+
"black swan",
|
| 112 |
+
"tusker",
|
| 113 |
+
"echidna",
|
| 114 |
+
"platypus",
|
| 115 |
+
"wallaby",
|
| 116 |
+
"koala",
|
| 117 |
+
"wombat",
|
| 118 |
+
"jellyfish",
|
| 119 |
+
"sea anemone",
|
| 120 |
+
"brain coral",
|
| 121 |
+
"flatworm",
|
| 122 |
+
"nematode",
|
| 123 |
+
"conch",
|
| 124 |
+
"snail",
|
| 125 |
+
"slug",
|
| 126 |
+
"sea slug",
|
| 127 |
+
"chiton",
|
| 128 |
+
"chambered nautilus",
|
| 129 |
+
"Dungeness crab",
|
| 130 |
+
"rock crab",
|
| 131 |
+
"fiddler crab",
|
| 132 |
+
"red king crab",
|
| 133 |
+
"American lobster",
|
| 134 |
+
"spiny lobster",
|
| 135 |
+
"crayfish",
|
| 136 |
+
"hermit crab",
|
| 137 |
+
"isopod",
|
| 138 |
+
"white stork",
|
| 139 |
+
"black stork",
|
| 140 |
+
"spoonbill",
|
| 141 |
+
"flamingo",
|
| 142 |
+
"little blue heron",
|
| 143 |
+
"great egret",
|
| 144 |
+
"bittern bird",
|
| 145 |
+
"crane bird",
|
| 146 |
+
"limpkin",
|
| 147 |
+
"common gallinule",
|
| 148 |
+
"American coot",
|
| 149 |
+
"bustard",
|
| 150 |
+
"ruddy turnstone",
|
| 151 |
+
"dunlin",
|
| 152 |
+
"common redshank",
|
| 153 |
+
"dowitcher",
|
| 154 |
+
"oystercatcher",
|
| 155 |
+
"pelican",
|
| 156 |
+
"king penguin",
|
| 157 |
+
"albatross",
|
| 158 |
+
"grey whale",
|
| 159 |
+
"killer whale",
|
| 160 |
+
"dugong",
|
| 161 |
+
"sea lion",
|
| 162 |
+
"Chihuahua",
|
| 163 |
+
"Japanese Chin",
|
| 164 |
+
"Maltese",
|
| 165 |
+
"Pekingese",
|
| 166 |
+
"Shih Tzu",
|
| 167 |
+
"King Charles Spaniel",
|
| 168 |
+
"Papillon",
|
| 169 |
+
"toy terrier",
|
| 170 |
+
"Rhodesian Ridgeback",
|
| 171 |
+
"Afghan Hound",
|
| 172 |
+
"Basset Hound",
|
| 173 |
+
"Beagle",
|
| 174 |
+
"Bloodhound",
|
| 175 |
+
"Bluetick Coonhound",
|
| 176 |
+
"Black and Tan Coonhound",
|
| 177 |
+
"Treeing Walker Coonhound",
|
| 178 |
+
"English foxhound",
|
| 179 |
+
"Redbone Coonhound",
|
| 180 |
+
"borzoi",
|
| 181 |
+
"Irish Wolfhound",
|
| 182 |
+
"Italian Greyhound",
|
| 183 |
+
"Whippet",
|
| 184 |
+
"Ibizan Hound",
|
| 185 |
+
"Norwegian Elkhound",
|
| 186 |
+
"Otterhound",
|
| 187 |
+
"Saluki",
|
| 188 |
+
"Scottish Deerhound",
|
| 189 |
+
"Weimaraner",
|
| 190 |
+
"Staffordshire Bull Terrier",
|
| 191 |
+
"American Staffordshire Terrier",
|
| 192 |
+
"Bedlington Terrier",
|
| 193 |
+
"Border Terrier",
|
| 194 |
+
"Kerry Blue Terrier",
|
| 195 |
+
"Irish Terrier",
|
| 196 |
+
"Norfolk Terrier",
|
| 197 |
+
"Norwich Terrier",
|
| 198 |
+
"Yorkshire Terrier",
|
| 199 |
+
"Wire Fox Terrier",
|
| 200 |
+
"Lakeland Terrier",
|
| 201 |
+
"Sealyham Terrier",
|
| 202 |
+
"Airedale Terrier",
|
| 203 |
+
"Cairn Terrier",
|
| 204 |
+
"Australian Terrier",
|
| 205 |
+
"Dandie Dinmont Terrier",
|
| 206 |
+
"Boston Terrier",
|
| 207 |
+
"Miniature Schnauzer",
|
| 208 |
+
"Giant Schnauzer",
|
| 209 |
+
"Standard Schnauzer",
|
| 210 |
+
"Scottish Terrier",
|
| 211 |
+
"Tibetan Terrier",
|
| 212 |
+
"Australian Silky Terrier",
|
| 213 |
+
"Soft-coated Wheaten Terrier",
|
| 214 |
+
"West Highland White Terrier",
|
| 215 |
+
"Lhasa Apso",
|
| 216 |
+
"Flat-Coated Retriever",
|
| 217 |
+
"Curly-coated Retriever",
|
| 218 |
+
"Golden Retriever",
|
| 219 |
+
"Labrador Retriever",
|
| 220 |
+
"Chesapeake Bay Retriever",
|
| 221 |
+
"German Shorthaired Pointer",
|
| 222 |
+
"Vizsla",
|
| 223 |
+
"English Setter",
|
| 224 |
+
"Irish Setter",
|
| 225 |
+
"Gordon Setter",
|
| 226 |
+
"Brittany dog",
|
| 227 |
+
"Clumber Spaniel",
|
| 228 |
+
"English Springer Spaniel",
|
| 229 |
+
"Welsh Springer Spaniel",
|
| 230 |
+
"Cocker Spaniel",
|
| 231 |
+
"Sussex Spaniel",
|
| 232 |
+
"Irish Water Spaniel",
|
| 233 |
+
"Kuvasz",
|
| 234 |
+
"Schipperke",
|
| 235 |
+
"Groenendael dog",
|
| 236 |
+
"Malinois",
|
| 237 |
+
"Briard",
|
| 238 |
+
"Australian Kelpie",
|
| 239 |
+
"Komondor",
|
| 240 |
+
"Old English Sheepdog",
|
| 241 |
+
"Shetland Sheepdog",
|
| 242 |
+
"collie",
|
| 243 |
+
"Border Collie",
|
| 244 |
+
"Bouvier des Flandres dog",
|
| 245 |
+
"Rottweiler",
|
| 246 |
+
"German Shepherd Dog",
|
| 247 |
+
"Dobermann",
|
| 248 |
+
"Miniature Pinscher",
|
| 249 |
+
"Greater Swiss Mountain Dog",
|
| 250 |
+
"Bernese Mountain Dog",
|
| 251 |
+
"Appenzeller Sennenhund",
|
| 252 |
+
"Entlebucher Sennenhund",
|
| 253 |
+
"Boxer",
|
| 254 |
+
"Bullmastiff",
|
| 255 |
+
"Tibetan Mastiff",
|
| 256 |
+
"French Bulldog",
|
| 257 |
+
"Great Dane",
|
| 258 |
+
"St. Bernard",
|
| 259 |
+
"husky",
|
| 260 |
+
"Alaskan Malamute",
|
| 261 |
+
"Siberian Husky",
|
| 262 |
+
"Dalmatian",
|
| 263 |
+
"Affenpinscher",
|
| 264 |
+
"Basenji",
|
| 265 |
+
"pug",
|
| 266 |
+
"Leonberger",
|
| 267 |
+
"Newfoundland dog",
|
| 268 |
+
"Great Pyrenees dog",
|
| 269 |
+
"Samoyed",
|
| 270 |
+
"Pomeranian",
|
| 271 |
+
"Chow Chow",
|
| 272 |
+
"Keeshond",
|
| 273 |
+
"brussels griffon",
|
| 274 |
+
"Pembroke Welsh Corgi",
|
| 275 |
+
"Cardigan Welsh Corgi",
|
| 276 |
+
"Toy Poodle",
|
| 277 |
+
"Miniature Poodle",
|
| 278 |
+
"Standard Poodle",
|
| 279 |
+
"Mexican hairless dog (xoloitzcuintli)",
|
| 280 |
+
"grey wolf",
|
| 281 |
+
"Alaskan tundra wolf",
|
| 282 |
+
"red wolf or maned wolf",
|
| 283 |
+
"coyote",
|
| 284 |
+
"dingo",
|
| 285 |
+
"dhole",
|
| 286 |
+
"African wild dog",
|
| 287 |
+
"hyena",
|
| 288 |
+
"red fox",
|
| 289 |
+
"kit fox",
|
| 290 |
+
"Arctic fox",
|
| 291 |
+
"grey fox",
|
| 292 |
+
"tabby cat",
|
| 293 |
+
"tiger cat",
|
| 294 |
+
"Persian cat",
|
| 295 |
+
"Siamese cat",
|
| 296 |
+
"Egyptian Mau",
|
| 297 |
+
"cougar",
|
| 298 |
+
"lynx",
|
| 299 |
+
"leopard",
|
| 300 |
+
"snow leopard",
|
| 301 |
+
"jaguar",
|
| 302 |
+
"lion",
|
| 303 |
+
"tiger",
|
| 304 |
+
"cheetah",
|
| 305 |
+
"brown bear",
|
| 306 |
+
"American black bear",
|
| 307 |
+
"polar bear",
|
| 308 |
+
"sloth bear",
|
| 309 |
+
"mongoose",
|
| 310 |
+
"meerkat",
|
| 311 |
+
"tiger beetle",
|
| 312 |
+
"ladybug",
|
| 313 |
+
"ground beetle",
|
| 314 |
+
"longhorn beetle",
|
| 315 |
+
"leaf beetle",
|
| 316 |
+
"dung beetle",
|
| 317 |
+
"rhinoceros beetle",
|
| 318 |
+
"weevil",
|
| 319 |
+
"fly",
|
| 320 |
+
"bee",
|
| 321 |
+
"ant",
|
| 322 |
+
"grasshopper",
|
| 323 |
+
"cricket insect",
|
| 324 |
+
"stick insect",
|
| 325 |
+
"cockroach",
|
| 326 |
+
"praying mantis",
|
| 327 |
+
"cicada",
|
| 328 |
+
"leafhopper",
|
| 329 |
+
"lacewing",
|
| 330 |
+
"dragonfly",
|
| 331 |
+
"damselfly",
|
| 332 |
+
"red admiral butterfly",
|
| 333 |
+
"ringlet butterfly",
|
| 334 |
+
"monarch butterfly",
|
| 335 |
+
"small white butterfly",
|
| 336 |
+
"sulphur butterfly",
|
| 337 |
+
"gossamer-winged butterfly",
|
| 338 |
+
"starfish",
|
| 339 |
+
"sea urchin",
|
| 340 |
+
"sea cucumber",
|
| 341 |
+
"cottontail rabbit",
|
| 342 |
+
"hare",
|
| 343 |
+
"Angora rabbit",
|
| 344 |
+
"hamster",
|
| 345 |
+
"porcupine",
|
| 346 |
+
"fox squirrel",
|
| 347 |
+
"marmot",
|
| 348 |
+
"beaver",
|
| 349 |
+
"guinea pig",
|
| 350 |
+
"common sorrel horse",
|
| 351 |
+
"zebra",
|
| 352 |
+
"pig",
|
| 353 |
+
"wild boar",
|
| 354 |
+
"warthog",
|
| 355 |
+
"hippopotamus",
|
| 356 |
+
"ox",
|
| 357 |
+
"water buffalo",
|
| 358 |
+
"bison",
|
| 359 |
+
"ram (adult male sheep)",
|
| 360 |
+
"bighorn sheep",
|
| 361 |
+
"Alpine ibex",
|
| 362 |
+
"hartebeest",
|
| 363 |
+
"impala (antelope)",
|
| 364 |
+
"gazelle",
|
| 365 |
+
"arabian camel",
|
| 366 |
+
"llama",
|
| 367 |
+
"weasel",
|
| 368 |
+
"mink",
|
| 369 |
+
"European polecat",
|
| 370 |
+
"black-footed ferret",
|
| 371 |
+
"otter",
|
| 372 |
+
"skunk",
|
| 373 |
+
"badger",
|
| 374 |
+
"armadillo",
|
| 375 |
+
"three-toed sloth",
|
| 376 |
+
"orangutan",
|
| 377 |
+
"gorilla",
|
| 378 |
+
"chimpanzee",
|
| 379 |
+
"gibbon",
|
| 380 |
+
"siamang",
|
| 381 |
+
"guenon",
|
| 382 |
+
"patas monkey",
|
| 383 |
+
"baboon",
|
| 384 |
+
"macaque",
|
| 385 |
+
"langur",
|
| 386 |
+
"black-and-white colobus",
|
| 387 |
+
"proboscis monkey",
|
| 388 |
+
"marmoset",
|
| 389 |
+
"white-headed capuchin",
|
| 390 |
+
"howler monkey",
|
| 391 |
+
"titi monkey",
|
| 392 |
+
"Geoffroy's spider monkey",
|
| 393 |
+
"common squirrel monkey",
|
| 394 |
+
"ring-tailed lemur",
|
| 395 |
+
"indri",
|
| 396 |
+
"Asian elephant",
|
| 397 |
+
"African bush elephant",
|
| 398 |
+
"red panda",
|
| 399 |
+
"giant panda",
|
| 400 |
+
"snoek fish",
|
| 401 |
+
"eel",
|
| 402 |
+
"silver salmon",
|
| 403 |
+
"rock beauty fish",
|
| 404 |
+
"clownfish",
|
| 405 |
+
"sturgeon",
|
| 406 |
+
"gar fish",
|
| 407 |
+
"lionfish",
|
| 408 |
+
"pufferfish",
|
| 409 |
+
"abacus",
|
| 410 |
+
"abaya",
|
| 411 |
+
"academic gown",
|
| 412 |
+
"accordion",
|
| 413 |
+
"acoustic guitar",
|
| 414 |
+
"aircraft carrier",
|
| 415 |
+
"airliner",
|
| 416 |
+
"airship",
|
| 417 |
+
"altar",
|
| 418 |
+
"ambulance",
|
| 419 |
+
"amphibious vehicle",
|
| 420 |
+
"analog clock",
|
| 421 |
+
"apiary",
|
| 422 |
+
"apron",
|
| 423 |
+
"trash can",
|
| 424 |
+
"assault rifle",
|
| 425 |
+
"backpack",
|
| 426 |
+
"bakery",
|
| 427 |
+
"balance beam",
|
| 428 |
+
"balloon",
|
| 429 |
+
"ballpoint pen",
|
| 430 |
+
"Band-Aid",
|
| 431 |
+
"banjo",
|
| 432 |
+
"baluster / handrail",
|
| 433 |
+
"barbell",
|
| 434 |
+
"barber chair",
|
| 435 |
+
"barbershop",
|
| 436 |
+
"barn",
|
| 437 |
+
"barometer",
|
| 438 |
+
"barrel",
|
| 439 |
+
"wheelbarrow",
|
| 440 |
+
"baseball",
|
| 441 |
+
"basketball",
|
| 442 |
+
"bassinet",
|
| 443 |
+
"bassoon",
|
| 444 |
+
"swimming cap",
|
| 445 |
+
"bath towel",
|
| 446 |
+
"bathtub",
|
| 447 |
+
"station wagon",
|
| 448 |
+
"lighthouse",
|
| 449 |
+
"beaker",
|
| 450 |
+
"military hat (bearskin or shako)",
|
| 451 |
+
"beer bottle",
|
| 452 |
+
"beer glass",
|
| 453 |
+
"bell tower",
|
| 454 |
+
"baby bib",
|
| 455 |
+
"tandem bicycle",
|
| 456 |
+
"bikini",
|
| 457 |
+
"ring binder",
|
| 458 |
+
"binoculars",
|
| 459 |
+
"birdhouse",
|
| 460 |
+
"boathouse",
|
| 461 |
+
"bobsleigh",
|
| 462 |
+
"bolo tie",
|
| 463 |
+
"poke bonnet",
|
| 464 |
+
"bookcase",
|
| 465 |
+
"bookstore",
|
| 466 |
+
"bottle cap",
|
| 467 |
+
"hunting bow",
|
| 468 |
+
"bow tie",
|
| 469 |
+
"brass memorial plaque",
|
| 470 |
+
"bra",
|
| 471 |
+
"breakwater",
|
| 472 |
+
"breastplate",
|
| 473 |
+
"broom",
|
| 474 |
+
"bucket",
|
| 475 |
+
"buckle",
|
| 476 |
+
"bulletproof vest",
|
| 477 |
+
"high-speed train",
|
| 478 |
+
"butcher shop",
|
| 479 |
+
"taxicab",
|
| 480 |
+
"cauldron",
|
| 481 |
+
"candle",
|
| 482 |
+
"cannon",
|
| 483 |
+
"canoe",
|
| 484 |
+
"can opener",
|
| 485 |
+
"cardigan",
|
| 486 |
+
"car mirror",
|
| 487 |
+
"carousel",
|
| 488 |
+
"tool kit",
|
| 489 |
+
"cardboard box / carton",
|
| 490 |
+
"car wheel",
|
| 491 |
+
"automated teller machine",
|
| 492 |
+
"cassette",
|
| 493 |
+
"cassette player",
|
| 494 |
+
"castle",
|
| 495 |
+
"catamaran",
|
| 496 |
+
"CD player",
|
| 497 |
+
"cello",
|
| 498 |
+
"mobile phone",
|
| 499 |
+
"chain",
|
| 500 |
+
"chain-link fence",
|
| 501 |
+
"chain mail",
|
| 502 |
+
"chainsaw",
|
| 503 |
+
"storage chest",
|
| 504 |
+
"chiffonier",
|
| 505 |
+
"bell or wind chime",
|
| 506 |
+
"china cabinet",
|
| 507 |
+
"Christmas stocking",
|
| 508 |
+
"church",
|
| 509 |
+
"movie theater",
|
| 510 |
+
"cleaver",
|
| 511 |
+
"cliff dwelling",
|
| 512 |
+
"cloak",
|
| 513 |
+
"clogs",
|
| 514 |
+
"cocktail shaker",
|
| 515 |
+
"coffee mug",
|
| 516 |
+
"coffeemaker",
|
| 517 |
+
"spiral or coil",
|
| 518 |
+
"combination lock",
|
| 519 |
+
"computer keyboard",
|
| 520 |
+
"candy store",
|
| 521 |
+
"container ship",
|
| 522 |
+
"convertible",
|
| 523 |
+
"corkscrew",
|
| 524 |
+
"cornet",
|
| 525 |
+
"cowboy boot",
|
| 526 |
+
"cowboy hat",
|
| 527 |
+
"cradle",
|
| 528 |
+
"construction crane",
|
| 529 |
+
"crash helmet",
|
| 530 |
+
"crate",
|
| 531 |
+
"infant bed",
|
| 532 |
+
"Crock Pot",
|
| 533 |
+
"croquet ball",
|
| 534 |
+
"crutch",
|
| 535 |
+
"cuirass",
|
| 536 |
+
"dam",
|
| 537 |
+
"desk",
|
| 538 |
+
"desktop computer",
|
| 539 |
+
"rotary dial telephone",
|
| 540 |
+
"diaper",
|
| 541 |
+
"digital clock",
|
| 542 |
+
"digital watch",
|
| 543 |
+
"dining table",
|
| 544 |
+
"dishcloth",
|
| 545 |
+
"dishwasher",
|
| 546 |
+
"disc brake",
|
| 547 |
+
"dock",
|
| 548 |
+
"dog sled",
|
| 549 |
+
"dome",
|
| 550 |
+
"doormat",
|
| 551 |
+
"drilling rig",
|
| 552 |
+
"drum",
|
| 553 |
+
"drumstick",
|
| 554 |
+
"dumbbell",
|
| 555 |
+
"Dutch oven",
|
| 556 |
+
"electric fan",
|
| 557 |
+
"electric guitar",
|
| 558 |
+
"electric locomotive",
|
| 559 |
+
"entertainment center",
|
| 560 |
+
"envelope",
|
| 561 |
+
"espresso machine",
|
| 562 |
+
"face powder",
|
| 563 |
+
"feather boa",
|
| 564 |
+
"filing cabinet",
|
| 565 |
+
"fireboat",
|
| 566 |
+
"fire truck",
|
| 567 |
+
"fire screen",
|
| 568 |
+
"flagpole",
|
| 569 |
+
"flute",
|
| 570 |
+
"folding chair",
|
| 571 |
+
"football helmet",
|
| 572 |
+
"forklift",
|
| 573 |
+
"fountain",
|
| 574 |
+
"fountain pen",
|
| 575 |
+
"four-poster bed",
|
| 576 |
+
"freight car",
|
| 577 |
+
"French horn",
|
| 578 |
+
"frying pan",
|
| 579 |
+
"fur coat",
|
| 580 |
+
"garbage truck",
|
| 581 |
+
"gas mask or respirator",
|
| 582 |
+
"gas pump",
|
| 583 |
+
"goblet",
|
| 584 |
+
"go-kart",
|
| 585 |
+
"golf ball",
|
| 586 |
+
"golf cart",
|
| 587 |
+
"gondola",
|
| 588 |
+
"gong",
|
| 589 |
+
"gown",
|
| 590 |
+
"grand piano",
|
| 591 |
+
"greenhouse",
|
| 592 |
+
"radiator grille",
|
| 593 |
+
"grocery store",
|
| 594 |
+
"guillotine",
|
| 595 |
+
"hair clip",
|
| 596 |
+
"hair spray",
|
| 597 |
+
"half-track",
|
| 598 |
+
"hammer",
|
| 599 |
+
"hamper",
|
| 600 |
+
"hair dryer",
|
| 601 |
+
"hand-held computer",
|
| 602 |
+
"handkerchief",
|
| 603 |
+
"hard disk drive",
|
| 604 |
+
"harmonica",
|
| 605 |
+
"harp",
|
| 606 |
+
"combine harvester",
|
| 607 |
+
"hatchet",
|
| 608 |
+
"holster",
|
| 609 |
+
"home theater",
|
| 610 |
+
"honeycomb",
|
| 611 |
+
"hook",
|
| 612 |
+
"hoop skirt",
|
| 613 |
+
"gymnastic horizontal bar",
|
| 614 |
+
"horse-drawn vehicle",
|
| 615 |
+
"hourglass",
|
| 616 |
+
"iPod",
|
| 617 |
+
"clothes iron",
|
| 618 |
+
"carved pumpkin",
|
| 619 |
+
"jeans",
|
| 620 |
+
"jeep",
|
| 621 |
+
"T-shirt",
|
| 622 |
+
"jigsaw puzzle",
|
| 623 |
+
"rickshaw",
|
| 624 |
+
"joystick",
|
| 625 |
+
"kimono",
|
| 626 |
+
"knee pad",
|
| 627 |
+
"knot",
|
| 628 |
+
"lab coat",
|
| 629 |
+
"ladle",
|
| 630 |
+
"lampshade",
|
| 631 |
+
"laptop computer",
|
| 632 |
+
"lawn mower",
|
| 633 |
+
"lens cap",
|
| 634 |
+
"letter opener",
|
| 635 |
+
"library",
|
| 636 |
+
"lifeboat",
|
| 637 |
+
"lighter",
|
| 638 |
+
"limousine",
|
| 639 |
+
"ocean liner",
|
| 640 |
+
"lipstick",
|
| 641 |
+
"slip-on shoe",
|
| 642 |
+
"lotion",
|
| 643 |
+
"music speaker",
|
| 644 |
+
"loupe magnifying glass",
|
| 645 |
+
"sawmill",
|
| 646 |
+
"magnetic compass",
|
| 647 |
+
"messenger bag",
|
| 648 |
+
"mailbox",
|
| 649 |
+
"tights",
|
| 650 |
+
"one-piece bathing suit",
|
| 651 |
+
"manhole cover",
|
| 652 |
+
"maraca",
|
| 653 |
+
"marimba",
|
| 654 |
+
"mask",
|
| 655 |
+
"matchstick",
|
| 656 |
+
"maypole",
|
| 657 |
+
"maze",
|
| 658 |
+
"measuring cup",
|
| 659 |
+
"medicine cabinet",
|
| 660 |
+
"megalith",
|
| 661 |
+
"microphone",
|
| 662 |
+
"microwave oven",
|
| 663 |
+
"military uniform",
|
| 664 |
+
"milk can",
|
| 665 |
+
"minibus",
|
| 666 |
+
"miniskirt",
|
| 667 |
+
"minivan",
|
| 668 |
+
"missile",
|
| 669 |
+
"mitten",
|
| 670 |
+
"mixing bowl",
|
| 671 |
+
"mobile home",
|
| 672 |
+
"ford model t",
|
| 673 |
+
"modem",
|
| 674 |
+
"monastery",
|
| 675 |
+
"monitor",
|
| 676 |
+
"moped",
|
| 677 |
+
"mortar and pestle",
|
| 678 |
+
"graduation cap",
|
| 679 |
+
"mosque",
|
| 680 |
+
"mosquito net",
|
| 681 |
+
"vespa",
|
| 682 |
+
"mountain bike",
|
| 683 |
+
"tent",
|
| 684 |
+
"computer mouse",
|
| 685 |
+
"mousetrap",
|
| 686 |
+
"moving van",
|
| 687 |
+
"muzzle",
|
| 688 |
+
"metal nail",
|
| 689 |
+
"neck brace",
|
| 690 |
+
"necklace",
|
| 691 |
+
"baby pacifier",
|
| 692 |
+
"notebook computer",
|
| 693 |
+
"obelisk",
|
| 694 |
+
"oboe",
|
| 695 |
+
"ocarina",
|
| 696 |
+
"odometer",
|
| 697 |
+
"oil filter",
|
| 698 |
+
"pipe organ",
|
| 699 |
+
"oscilloscope",
|
| 700 |
+
"overskirt",
|
| 701 |
+
"bullock cart",
|
| 702 |
+
"oxygen mask",
|
| 703 |
+
"product packet / packaging",
|
| 704 |
+
"paddle",
|
| 705 |
+
"paddle wheel",
|
| 706 |
+
"padlock",
|
| 707 |
+
"paintbrush",
|
| 708 |
+
"pajamas",
|
| 709 |
+
"palace",
|
| 710 |
+
"pan flute",
|
| 711 |
+
"paper towel",
|
| 712 |
+
"parachute",
|
| 713 |
+
"parallel bars",
|
| 714 |
+
"park bench",
|
| 715 |
+
"parking meter",
|
| 716 |
+
"railroad car",
|
| 717 |
+
"patio",
|
| 718 |
+
"payphone",
|
| 719 |
+
"pedestal",
|
| 720 |
+
"pencil case",
|
| 721 |
+
"pencil sharpener",
|
| 722 |
+
"perfume",
|
| 723 |
+
"Petri dish",
|
| 724 |
+
"photocopier",
|
| 725 |
+
"plectrum",
|
| 726 |
+
"Pickelhaube",
|
| 727 |
+
"picket fence",
|
| 728 |
+
"pickup truck",
|
| 729 |
+
"pier",
|
| 730 |
+
"piggy bank",
|
| 731 |
+
"pill bottle",
|
| 732 |
+
"pillow",
|
| 733 |
+
"ping-pong ball",
|
| 734 |
+
"pinwheel",
|
| 735 |
+
"pirate ship",
|
| 736 |
+
"drink pitcher",
|
| 737 |
+
"block plane",
|
| 738 |
+
"planetarium",
|
| 739 |
+
"plastic bag",
|
| 740 |
+
"plate rack",
|
| 741 |
+
"farm plow",
|
| 742 |
+
"plunger",
|
| 743 |
+
"Polaroid camera",
|
| 744 |
+
"pole",
|
| 745 |
+
"police van",
|
| 746 |
+
"poncho",
|
| 747 |
+
"pool table",
|
| 748 |
+
"soda bottle",
|
| 749 |
+
"plant pot",
|
| 750 |
+
"potter's wheel",
|
| 751 |
+
"power drill",
|
| 752 |
+
"prayer rug",
|
| 753 |
+
"printer",
|
| 754 |
+
"prison",
|
| 755 |
+
"missile",
|
| 756 |
+
"projector",
|
| 757 |
+
"hockey puck",
|
| 758 |
+
"punching bag",
|
| 759 |
+
"purse",
|
| 760 |
+
"quill",
|
| 761 |
+
"quilt",
|
| 762 |
+
"race car",
|
| 763 |
+
"racket",
|
| 764 |
+
"radiator",
|
| 765 |
+
"radio",
|
| 766 |
+
"radio telescope",
|
| 767 |
+
"rain barrel",
|
| 768 |
+
"recreational vehicle",
|
| 769 |
+
"fishing casting reel",
|
| 770 |
+
"reflex camera",
|
| 771 |
+
"refrigerator",
|
| 772 |
+
"remote control",
|
| 773 |
+
"restaurant",
|
| 774 |
+
"revolver",
|
| 775 |
+
"rifle",
|
| 776 |
+
"rocking chair",
|
| 777 |
+
"rotisserie",
|
| 778 |
+
"eraser",
|
| 779 |
+
"rugby ball",
|
| 780 |
+
"ruler measuring stick",
|
| 781 |
+
"sneaker",
|
| 782 |
+
"safe",
|
| 783 |
+
"safety pin",
|
| 784 |
+
"salt shaker",
|
| 785 |
+
"sandal",
|
| 786 |
+
"sarong",
|
| 787 |
+
"saxophone",
|
| 788 |
+
"scabbard",
|
| 789 |
+
"weighing scale",
|
| 790 |
+
"school bus",
|
| 791 |
+
"schooner",
|
| 792 |
+
"scoreboard",
|
| 793 |
+
"CRT monitor",
|
| 794 |
+
"screw",
|
| 795 |
+
"screwdriver",
|
| 796 |
+
"seat belt",
|
| 797 |
+
"sewing machine",
|
| 798 |
+
"shield",
|
| 799 |
+
"shoe store",
|
| 800 |
+
"shoji screen / room divider",
|
| 801 |
+
"shopping basket",
|
| 802 |
+
"shopping cart",
|
| 803 |
+
"shovel",
|
| 804 |
+
"shower cap",
|
| 805 |
+
"shower curtain",
|
| 806 |
+
"ski",
|
| 807 |
+
"balaclava ski mask",
|
| 808 |
+
"sleeping bag",
|
| 809 |
+
"slide rule",
|
| 810 |
+
"sliding door",
|
| 811 |
+
"slot machine",
|
| 812 |
+
"snorkel",
|
| 813 |
+
"snowmobile",
|
| 814 |
+
"snowplow",
|
| 815 |
+
"soap dispenser",
|
| 816 |
+
"soccer ball",
|
| 817 |
+
"sock",
|
| 818 |
+
"solar thermal collector",
|
| 819 |
+
"sombrero",
|
| 820 |
+
"soup bowl",
|
| 821 |
+
"keyboard space bar",
|
| 822 |
+
"space heater",
|
| 823 |
+
"space shuttle",
|
| 824 |
+
"spatula",
|
| 825 |
+
"motorboat",
|
| 826 |
+
"spider web",
|
| 827 |
+
"spindle",
|
| 828 |
+
"sports car",
|
| 829 |
+
"spotlight",
|
| 830 |
+
"stage",
|
| 831 |
+
"steam locomotive",
|
| 832 |
+
"through arch bridge",
|
| 833 |
+
"steel drum",
|
| 834 |
+
"stethoscope",
|
| 835 |
+
"scarf",
|
| 836 |
+
"stone wall",
|
| 837 |
+
"stopwatch",
|
| 838 |
+
"stove",
|
| 839 |
+
"strainer",
|
| 840 |
+
"tram",
|
| 841 |
+
"stretcher",
|
| 842 |
+
"couch",
|
| 843 |
+
"stupa",
|
| 844 |
+
"submarine",
|
| 845 |
+
"suit",
|
| 846 |
+
"sundial",
|
| 847 |
+
"sunglasses",
|
| 848 |
+
"sunglasses",
|
| 849 |
+
"sunscreen",
|
| 850 |
+
"suspension bridge",
|
| 851 |
+
"mop",
|
| 852 |
+
"sweatshirt",
|
| 853 |
+
"swim trunks / shorts",
|
| 854 |
+
"swing",
|
| 855 |
+
"electrical switch",
|
| 856 |
+
"syringe",
|
| 857 |
+
"table lamp",
|
| 858 |
+
"tank",
|
| 859 |
+
"tape player",
|
| 860 |
+
"teapot",
|
| 861 |
+
"teddy bear",
|
| 862 |
+
"television",
|
| 863 |
+
"tennis ball",
|
| 864 |
+
"thatched roof",
|
| 865 |
+
"front curtain",
|
| 866 |
+
"thimble",
|
| 867 |
+
"threshing machine",
|
| 868 |
+
"throne",
|
| 869 |
+
"tile roof",
|
| 870 |
+
"toaster",
|
| 871 |
+
"tobacco shop",
|
| 872 |
+
"toilet seat",
|
| 873 |
+
"torch",
|
| 874 |
+
"totem pole",
|
| 875 |
+
"tow truck",
|
| 876 |
+
"toy store",
|
| 877 |
+
"tractor",
|
| 878 |
+
"semi-trailer truck",
|
| 879 |
+
"tray",
|
| 880 |
+
"trench coat",
|
| 881 |
+
"tricycle",
|
| 882 |
+
"trimaran",
|
| 883 |
+
"tripod",
|
| 884 |
+
"triumphal arch",
|
| 885 |
+
"trolleybus",
|
| 886 |
+
"trombone",
|
| 887 |
+
"hot tub",
|
| 888 |
+
"turnstile",
|
| 889 |
+
"typewriter keyboard",
|
| 890 |
+
"umbrella",
|
| 891 |
+
"unicycle",
|
| 892 |
+
"upright piano",
|
| 893 |
+
"vacuum cleaner",
|
| 894 |
+
"vase",
|
| 895 |
+
"vaulted or arched ceiling",
|
| 896 |
+
"velvet fabric",
|
| 897 |
+
"vending machine",
|
| 898 |
+
"vestment",
|
| 899 |
+
"viaduct",
|
| 900 |
+
"violin",
|
| 901 |
+
"volleyball",
|
| 902 |
+
"waffle iron",
|
| 903 |
+
"wall clock",
|
| 904 |
+
"wallet",
|
| 905 |
+
"wardrobe",
|
| 906 |
+
"military aircraft",
|
| 907 |
+
"sink",
|
| 908 |
+
"washing machine",
|
| 909 |
+
"water bottle",
|
| 910 |
+
"water jug",
|
| 911 |
+
"water tower",
|
| 912 |
+
"whiskey jug",
|
| 913 |
+
"whistle",
|
| 914 |
+
"hair wig",
|
| 915 |
+
"window screen",
|
| 916 |
+
"window shade",
|
| 917 |
+
"Windsor tie",
|
| 918 |
+
"wine bottle",
|
| 919 |
+
"airplane wing",
|
| 920 |
+
"wok",
|
| 921 |
+
"wooden spoon",
|
| 922 |
+
"wool",
|
| 923 |
+
"split-rail fence",
|
| 924 |
+
"shipwreck",
|
| 925 |
+
"sailboat",
|
| 926 |
+
"yurt",
|
| 927 |
+
"website",
|
| 928 |
+
"comic book",
|
| 929 |
+
"crossword",
|
| 930 |
+
"traffic or street sign",
|
| 931 |
+
"traffic light",
|
| 932 |
+
"dust jacket",
|
| 933 |
+
"menu",
|
| 934 |
+
"plate",
|
| 935 |
+
"guacamole",
|
| 936 |
+
"consomme",
|
| 937 |
+
"hot pot",
|
| 938 |
+
"trifle",
|
| 939 |
+
"ice cream",
|
| 940 |
+
"popsicle",
|
| 941 |
+
"baguette",
|
| 942 |
+
"bagel",
|
| 943 |
+
"pretzel",
|
| 944 |
+
"cheeseburger",
|
| 945 |
+
"hot dog",
|
| 946 |
+
"mashed potatoes",
|
| 947 |
+
"cabbage",
|
| 948 |
+
"broccoli",
|
| 949 |
+
"cauliflower",
|
| 950 |
+
"zucchini",
|
| 951 |
+
"spaghetti squash",
|
| 952 |
+
"acorn squash",
|
| 953 |
+
"butternut squash",
|
| 954 |
+
"cucumber",
|
| 955 |
+
"artichoke",
|
| 956 |
+
"bell pepper",
|
| 957 |
+
"cardoon",
|
| 958 |
+
"mushroom",
|
| 959 |
+
"Granny Smith apple",
|
| 960 |
+
"strawberry",
|
| 961 |
+
"orange",
|
| 962 |
+
"lemon",
|
| 963 |
+
"fig",
|
| 964 |
+
"pineapple",
|
| 965 |
+
"banana",
|
| 966 |
+
"jackfruit",
|
| 967 |
+
"cherimoya (custard apple)",
|
| 968 |
+
"pomegranate",
|
| 969 |
+
"hay",
|
| 970 |
+
"carbonara",
|
| 971 |
+
"chocolate syrup",
|
| 972 |
+
"dough",
|
| 973 |
+
"meatloaf",
|
| 974 |
+
"pizza",
|
| 975 |
+
"pot pie",
|
| 976 |
+
"burrito",
|
| 977 |
+
"red wine",
|
| 978 |
+
"espresso",
|
| 979 |
+
"tea cup",
|
| 980 |
+
"eggnog",
|
| 981 |
+
"mountain",
|
| 982 |
+
"bubble",
|
| 983 |
+
"cliff",
|
| 984 |
+
"coral reef",
|
| 985 |
+
"geyser",
|
| 986 |
+
"lakeshore",
|
| 987 |
+
"promontory",
|
| 988 |
+
"sandbar",
|
| 989 |
+
"beach",
|
| 990 |
+
"valley",
|
| 991 |
+
"volcano",
|
| 992 |
+
"baseball player",
|
| 993 |
+
"bridegroom",
|
| 994 |
+
"scuba diver",
|
| 995 |
+
"rapeseed",
|
| 996 |
+
"daisy",
|
| 997 |
+
"yellow lady's slipper",
|
| 998 |
+
"corn",
|
| 999 |
+
"acorn",
|
| 1000 |
+
"rose hip",
|
| 1001 |
+
"horse chestnut seed",
|
| 1002 |
+
"coral fungus",
|
| 1003 |
+
"agaric",
|
| 1004 |
+
"gyromitra",
|
| 1005 |
+
"stinkhorn mushroom",
|
| 1006 |
+
"earth star fungus",
|
| 1007 |
+
"hen of the woods mushroom",
|
| 1008 |
+
"bolete",
|
| 1009 |
+
"corn cob",
|
| 1010 |
+
"toilet paper",
|
| 1011 |
+
]
|
| 1012 |
+
|
| 1013 |
+
|
| 1014 |
+
openai_imagenet_template = [
|
| 1015 |
+
lambda c: f"a bad photo of a {c}.",
|
| 1016 |
+
lambda c: f"a photo of many {c}.",
|
| 1017 |
+
lambda c: f"a sculpture of a {c}.",
|
| 1018 |
+
lambda c: f"a photo of the hard to see {c}.",
|
| 1019 |
+
lambda c: f"a low resolution photo of the {c}.",
|
| 1020 |
+
lambda c: f"a rendering of a {c}.",
|
| 1021 |
+
lambda c: f"graffiti of a {c}.",
|
| 1022 |
+
lambda c: f"a bad photo of the {c}.",
|
| 1023 |
+
lambda c: f"a cropped photo of the {c}.",
|
| 1024 |
+
lambda c: f"a tattoo of a {c}.",
|
| 1025 |
+
lambda c: f"the embroidered {c}.",
|
| 1026 |
+
lambda c: f"a photo of a hard to see {c}.",
|
| 1027 |
+
lambda c: f"a bright photo of a {c}.",
|
| 1028 |
+
lambda c: f"a photo of a clean {c}.",
|
| 1029 |
+
lambda c: f"a photo of a dirty {c}.",
|
| 1030 |
+
lambda c: f"a dark photo of the {c}.",
|
| 1031 |
+
lambda c: f"a drawing of a {c}.",
|
| 1032 |
+
lambda c: f"a photo of my {c}.",
|
| 1033 |
+
lambda c: f"the plastic {c}.",
|
| 1034 |
+
lambda c: f"a photo of the cool {c}.",
|
| 1035 |
+
lambda c: f"a close-up photo of a {c}.",
|
| 1036 |
+
lambda c: f"a black and white photo of the {c}.",
|
| 1037 |
+
lambda c: f"a painting of the {c}.",
|
| 1038 |
+
lambda c: f"a painting of a {c}.",
|
| 1039 |
+
lambda c: f"a pixelated photo of the {c}.",
|
| 1040 |
+
lambda c: f"a sculpture of the {c}.",
|
| 1041 |
+
lambda c: f"a bright photo of the {c}.",
|
| 1042 |
+
lambda c: f"a cropped photo of a {c}.",
|
| 1043 |
+
lambda c: f"a plastic {c}.",
|
| 1044 |
+
lambda c: f"a photo of the dirty {c}.",
|
| 1045 |
+
lambda c: f"a jpeg corrupted photo of a {c}.",
|
| 1046 |
+
lambda c: f"a blurry photo of the {c}.",
|
| 1047 |
+
lambda c: f"a photo of the {c}.",
|
| 1048 |
+
lambda c: f"a good photo of the {c}.",
|
| 1049 |
+
lambda c: f"a rendering of the {c}.",
|
| 1050 |
+
lambda c: f"a {c} in a video game.",
|
| 1051 |
+
lambda c: f"a photo of one {c}.",
|
| 1052 |
+
lambda c: f"a doodle of a {c}.",
|
| 1053 |
+
lambda c: f"a close-up photo of the {c}.",
|
| 1054 |
+
lambda c: f"a photo of a {c}.",
|
| 1055 |
+
lambda c: f"the origami {c}.",
|
| 1056 |
+
lambda c: f"the {c} in a video game.",
|
| 1057 |
+
lambda c: f"a sketch of a {c}.",
|
| 1058 |
+
lambda c: f"a doodle of the {c}.",
|
| 1059 |
+
lambda c: f"a origami {c}.",
|
| 1060 |
+
lambda c: f"a low resolution photo of a {c}.",
|
| 1061 |
+
lambda c: f"the toy {c}.",
|
| 1062 |
+
lambda c: f"a rendition of the {c}.",
|
| 1063 |
+
lambda c: f"a photo of the clean {c}.",
|
| 1064 |
+
lambda c: f"a photo of a large {c}.",
|
| 1065 |
+
lambda c: f"a rendition of a {c}.",
|
| 1066 |
+
lambda c: f"a photo of a nice {c}.",
|
| 1067 |
+
lambda c: f"a photo of a weird {c}.",
|
| 1068 |
+
lambda c: f"a blurry photo of a {c}.",
|
| 1069 |
+
lambda c: f"a cartoon {c}.",
|
| 1070 |
+
lambda c: f"art of a {c}.",
|
| 1071 |
+
lambda c: f"a sketch of the {c}.",
|
| 1072 |
+
lambda c: f"a embroidered {c}.",
|
| 1073 |
+
lambda c: f"a pixelated photo of a {c}.",
|
| 1074 |
+
lambda c: f"itap of the {c}.",
|
| 1075 |
+
lambda c: f"a jpeg corrupted photo of the {c}.",
|
| 1076 |
+
lambda c: f"a good photo of a {c}.",
|
| 1077 |
+
lambda c: f"a plushie {c}.",
|
| 1078 |
+
lambda c: f"a photo of the nice {c}.",
|
| 1079 |
+
lambda c: f"a photo of the small {c}.",
|
| 1080 |
+
lambda c: f"a photo of the weird {c}.",
|
| 1081 |
+
lambda c: f"the cartoon {c}.",
|
| 1082 |
+
lambda c: f"art of the {c}.",
|
| 1083 |
+
lambda c: f"a drawing of the {c}.",
|
| 1084 |
+
lambda c: f"a photo of the large {c}.",
|
| 1085 |
+
lambda c: f"a black and white photo of a {c}.",
|
| 1086 |
+
lambda c: f"the plushie {c}.",
|
| 1087 |
+
lambda c: f"a dark photo of a {c}.",
|
| 1088 |
+
lambda c: f"itap of a {c}.",
|
| 1089 |
+
lambda c: f"graffiti of the {c}.",
|
| 1090 |
+
lambda c: f"a toy {c}.",
|
| 1091 |
+
lambda c: f"itap of my {c}.",
|
| 1092 |
+
lambda c: f"a photo of a cool {c}.",
|
| 1093 |
+
lambda c: f"a photo of a small {c}.",
|
| 1094 |
+
lambda c: f"a tattoo of the {c}.",
|
| 1095 |
+
]
|
multimodal/examples/flava/data/transforms.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import random
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import Any, Callable, Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torchmultimodal.transforms.flava_transform import FLAVAImageTransform
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
from transformers import BertTokenizer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
| 18 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
| 19 |
+
IMAGE_DEFAULT_SIZE = (224, 224)
|
| 20 |
+
VL_MAX_LENGTH_DEFAULT = 77
|
| 21 |
+
TEXT_MAX_LENGTH_DEFAULT = 512
|
| 22 |
+
TEXT_DEFAULT_TOKENIZER = "bert-base-uncased"
|
| 23 |
+
TEXT_WHOLE_WORD_MASK_TOKENIZER = "bert-large-uncased-whole-word-masking"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def encode_text(text, tokenizer, *args, **kwargs):
|
| 27 |
+
return tokenizer(text, *args, **kwargs)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def encode_text_batch(
|
| 31 |
+
batch, tokenizer, text_columns, return_batch=False, *args, **kwargs
|
| 32 |
+
):
|
| 33 |
+
texts = [batch[column] for column in text_columns]
|
| 34 |
+
tokens = tokenizer(*texts, *args, **kwargs)
|
| 35 |
+
if return_batch:
|
| 36 |
+
batch.update(tokens)
|
| 37 |
+
return batch
|
| 38 |
+
return tokens
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def transform_image_dict(transform, image_dict, *args, **kwargs):
|
| 42 |
+
return {"image": transform(image_dict["image"], *args, **kwargs)}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def default_torchvision_transforms(
|
| 46 |
+
size=IMAGE_DEFAULT_SIZE,
|
| 47 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
| 48 |
+
std=IMAGENET_DEFAULT_STD,
|
| 49 |
+
use_dict=False,
|
| 50 |
+
):
|
| 51 |
+
transform = transforms.Compose(
|
| 52 |
+
[
|
| 53 |
+
transforms.Resize(size),
|
| 54 |
+
transforms.ToTensor(),
|
| 55 |
+
transforms.Normalize(
|
| 56 |
+
mean=mean,
|
| 57 |
+
std=std,
|
| 58 |
+
),
|
| 59 |
+
]
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
if use_dict:
|
| 63 |
+
transform = partial(transform_image_dict, transform=transform)
|
| 64 |
+
|
| 65 |
+
return transform, transform
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def default_image_pretraining_transforms():
|
| 69 |
+
return FLAVAImageTransform(), FLAVAImageTransform(is_train=False)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def default_text_transform(
|
| 73 |
+
text_tokenizer: Optional[Callable] = None,
|
| 74 |
+
max_text_length: int = TEXT_MAX_LENGTH_DEFAULT,
|
| 75 |
+
**kwargs: Any,
|
| 76 |
+
):
|
| 77 |
+
if text_tokenizer is None:
|
| 78 |
+
text_tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
|
| 79 |
+
|
| 80 |
+
text_transform = partial(
|
| 81 |
+
encode_text,
|
| 82 |
+
tokenizer=text_tokenizer,
|
| 83 |
+
padding="max_length",
|
| 84 |
+
max_length=max_text_length,
|
| 85 |
+
truncation=True,
|
| 86 |
+
return_tensors="pt",
|
| 87 |
+
return_special_tokens_mask=True,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
return text_transform
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def default_vl_text_transform(
|
| 94 |
+
text_tokenizer: Optional[Callable] = None,
|
| 95 |
+
max_text_length: int = VL_MAX_LENGTH_DEFAULT,
|
| 96 |
+
**kwargs: Any,
|
| 97 |
+
):
|
| 98 |
+
if text_tokenizer is None:
|
| 99 |
+
text_tokenizer = BertTokenizer.from_pretrained(TEXT_WHOLE_WORD_MASK_TOKENIZER)
|
| 100 |
+
return default_text_transform(text_tokenizer, max_text_length=max_text_length)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def pad_batch(batch, batch_size):
|
| 104 |
+
for item in batch.keys():
|
| 105 |
+
if isinstance(batch[item], torch.Tensor):
|
| 106 |
+
diff = batch_size - batch[item].size(0)
|
| 107 |
+
pad = batch[item][-diff:].detach().clone()
|
| 108 |
+
batch[item] = torch.cat([batch[item], pad], dim=0)
|
| 109 |
+
return batch
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class VLTransform:
|
| 113 |
+
def __init__(self, image_transform, text_transform):
|
| 114 |
+
self.image_transform = image_transform
|
| 115 |
+
self.text_transform = text_transform
|
| 116 |
+
|
| 117 |
+
def __call__(self, info, dataset, itm_probability):
|
| 118 |
+
output = {}
|
| 119 |
+
text = info["text"]
|
| 120 |
+
image = info["image"]
|
| 121 |
+
if itm_probability > 0:
|
| 122 |
+
output["itm_labels"] = torch.ones((1), dtype=torch.long)
|
| 123 |
+
|
| 124 |
+
if random.random() < itm_probability:
|
| 125 |
+
while text == info["text"]:
|
| 126 |
+
text = dataset.select([random.randint(0, len(dataset) - 1)])[0]["text"]
|
| 127 |
+
output["itm_labels"] = torch.zeros((1), dtype=torch.long)
|
| 128 |
+
|
| 129 |
+
output.update(self.image_transform(image))
|
| 130 |
+
output.update(self.text_transform(text))
|
| 131 |
+
return output
|
multimodal/examples/flava/data/utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import List
|
| 11 |
+
|
| 12 |
+
import requests
|
| 13 |
+
from datasets import concatenate_datasets, load_dataset
|
| 14 |
+
from datasets.utils.file_utils import get_datasets_user_agent
|
| 15 |
+
from flava.definitions import HFDatasetInfo
|
| 16 |
+
from PIL import Image, UnidentifiedImageError
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
DATASETS_USER_AGENT = get_datasets_user_agent()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def build_datasets_from_info(dataset_infos: List[HFDatasetInfo], split: str = "train"):
|
| 23 |
+
dataset_list = []
|
| 24 |
+
for dataset_info in dataset_infos:
|
| 25 |
+
current_dataset = load_dataset(
|
| 26 |
+
dataset_info.key,
|
| 27 |
+
dataset_info.subset,
|
| 28 |
+
split=dataset_info.split_key_mapping[split],
|
| 29 |
+
use_auth_token=True,
|
| 30 |
+
**dataset_info.extra_kwargs,
|
| 31 |
+
)
|
| 32 |
+
if dataset_info.remove_columns is not None:
|
| 33 |
+
current_dataset = current_dataset.remove_columns(
|
| 34 |
+
dataset_info.remove_columns
|
| 35 |
+
)
|
| 36 |
+
if dataset_info.rename_columns is not None:
|
| 37 |
+
for rename in dataset_info.rename_columns:
|
| 38 |
+
current_dataset = current_dataset.rename_column(rename[0], rename[1])
|
| 39 |
+
|
| 40 |
+
dataset_list.append(current_dataset)
|
| 41 |
+
|
| 42 |
+
return concatenate_datasets(dataset_list)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def fetch_single_image(image_url, timeout, retries=0, sleep_timer=0):
|
| 46 |
+
for _ in range(retries + 1):
|
| 47 |
+
try:
|
| 48 |
+
image = Image.open(
|
| 49 |
+
requests.get(
|
| 50 |
+
image_url,
|
| 51 |
+
stream=True,
|
| 52 |
+
headers={"user-agent": DATASETS_USER_AGENT},
|
| 53 |
+
timeout=timeout,
|
| 54 |
+
).raw
|
| 55 |
+
)
|
| 56 |
+
break
|
| 57 |
+
except (requests.exceptions.ConnectionError, UnidentifiedImageError):
|
| 58 |
+
image = None
|
| 59 |
+
time.sleep(sleep_timer)
|
| 60 |
+
|
| 61 |
+
return image
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def fetch_images(batch, num_threads, timeout=None, retries=0, sleep_timer=0):
|
| 65 |
+
if "image" in batch:
|
| 66 |
+
# This dataset already has "image" defined.
|
| 67 |
+
return batch
|
| 68 |
+
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
| 69 |
+
batch["image"] = list(
|
| 70 |
+
executor.map(
|
| 71 |
+
partial(
|
| 72 |
+
fetch_single_image,
|
| 73 |
+
timeout=timeout,
|
| 74 |
+
retries=retries,
|
| 75 |
+
sleep_timer=sleep_timer,
|
| 76 |
+
),
|
| 77 |
+
batch["image_url"],
|
| 78 |
+
)
|
| 79 |
+
)
|
| 80 |
+
return batch
|
multimodal/examples/flava/native/README.md
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Usage Instructions
|
| 2 |
+
|
| 3 |
+
This is a lightweight native pytorch implementation to run scaling studies on the FLAVA model. The original code is located at: [`examples/flava/train.py`](https://github.com/facebookresearch/multimodal/blob/main/examples/flava/train.py)
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
- Install torchmultimodal library [from source](https://github.com/facebookresearch/multimodal/blob/main/README.md#building-from-source)
|
| 8 |
+
- `cd multimodal/examples`
|
| 9 |
+
- `pip install -r flava/requirements.txt`
|
| 10 |
+
|
| 11 |
+
## Training
|
| 12 |
+
|
| 13 |
+
### Configuration
|
| 14 |
+
|
| 15 |
+
Configuration presets for various model sizes can be found at: `examples/flava/native/configs`
|
| 16 |
+
|
| 17 |
+
Some config settings that are relevant for scaling: (local) `batch_size`, `activation_checkpointing`, `strategy`.
|
| 18 |
+
|
| 19 |
+
Configs can be overridden through command line, for example: `python -m flava.native.train config=flava/native/configs/pretrain_debug.yaml training.batch_size=8 training.enable_amp=True training.activation_checkpointing=True training.strategy=fsdp`
|
| 20 |
+
|
| 21 |
+
### Running
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
Using [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html):
|
| 25 |
+
|
| 26 |
+
**Single node**
|
| 27 |
+
|
| 28 |
+
`NUM_GPUS=8; torchrun --nproc_per_node=$NUM_GPUS -m flava.native.train config=flava/native/configs/pretrain_debug.yaml`
|
| 29 |
+
|
| 30 |
+
**Multiple nodes (using slurm)**
|
| 31 |
+
|
| 32 |
+
Create a `run.slurm` file:
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
RDZV_ENDPOINT=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
| 36 |
+
|
| 37 |
+
srun torchrun --nnodes=$SLURM_NNODES --nproc_per_node=$SLURM_GPUS_PER_TASK --rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$RDZV_ENDPOINT --max_restarts 0 -m flava.native.train config=flava/native/configs/pretrain_debug.yaml
|
| 38 |
+
$@
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
Run in terminal:
|
| 42 |
+
|
| 43 |
+
`sbatch --partition=[PARTITION] --nodes=[NUM_NODES] --gpus-per-task=[NUM_GPUS_PER_NODE] run.slurm`
|
multimodal/examples/flava/native/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
multimodal/examples/flava/native/configs/1.8b.yaml
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
training:
|
| 2 |
+
strategy: fsdp # can be changed to ddp or fsdp
|
| 3 |
+
seed: 1337
|
| 4 |
+
|
| 5 |
+
batch_size: 8
|
| 6 |
+
num_workers: 4
|
| 7 |
+
prefetch_factor: 3
|
| 8 |
+
|
| 9 |
+
optimizer:
|
| 10 |
+
learning_rate: 1e-3
|
| 11 |
+
adam_eps: 1e-8
|
| 12 |
+
adam_weight_decay: 0.1
|
| 13 |
+
adam_betas: [0.9, 0.999]
|
| 14 |
+
|
| 15 |
+
warmup_steps: 10000
|
| 16 |
+
max_steps: 100000
|
| 17 |
+
|
| 18 |
+
validation_steps: 5000
|
| 19 |
+
log_interval: 10
|
| 20 |
+
|
| 21 |
+
enable_tf32: True
|
| 22 |
+
enable_amp: True
|
| 23 |
+
half_precision_format: "bfloat16" # or float16
|
| 24 |
+
enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
|
| 25 |
+
|
| 26 |
+
activation_checkpointing: True
|
| 27 |
+
|
| 28 |
+
datasets:
|
| 29 |
+
_target_: flava.definitions.TrainingDatasetsInfo
|
| 30 |
+
selected:
|
| 31 |
+
- image
|
| 32 |
+
- vl
|
| 33 |
+
- text
|
| 34 |
+
image:
|
| 35 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 36 |
+
train:
|
| 37 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 38 |
+
key: imagenet-1k
|
| 39 |
+
subset: default
|
| 40 |
+
text:
|
| 41 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 42 |
+
train:
|
| 43 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 44 |
+
key: wikitext
|
| 45 |
+
subset: wikitext-103-raw-v1
|
| 46 |
+
datamodule_extra_kwargs:
|
| 47 |
+
text_columns: ["text"]
|
| 48 |
+
vl:
|
| 49 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 50 |
+
train:
|
| 51 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 52 |
+
key: red_caps
|
| 53 |
+
subset: backpacking
|
| 54 |
+
rename_columns:
|
| 55 |
+
- ["caption", "text"]
|
| 56 |
+
val:
|
| 57 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 58 |
+
key: red_caps
|
| 59 |
+
subset: backpacking
|
| 60 |
+
rename_columns:
|
| 61 |
+
- ["caption", "text"]
|
| 62 |
+
split_key_mapping:
|
| 63 |
+
validation: train
|
| 64 |
+
|
| 65 |
+
model:
|
| 66 |
+
image_num_hidden_layers: 32
|
| 67 |
+
image_hidden_size: 1280
|
| 68 |
+
image_intermediate_size: 5120
|
| 69 |
+
image_num_attention_heads: 16
|
| 70 |
+
|
| 71 |
+
text_num_hidden_layers: 32
|
| 72 |
+
text_hidden_size: 1280
|
| 73 |
+
text_intermediate_size: 5120
|
| 74 |
+
text_num_attention_heads: 16
|
| 75 |
+
|
| 76 |
+
multimodal_num_hidden_layers: 16
|
| 77 |
+
multimodal_hidden_size: 1280
|
| 78 |
+
multimodal_intermediate_size: 5120
|
| 79 |
+
multimodal_num_attention_heads: 16
|
multimodal/examples/flava/native/configs/10b.yaml
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
training:
|
| 2 |
+
strategy: fsdp # can be changed to ddp or fsdp
|
| 3 |
+
seed: 1337
|
| 4 |
+
|
| 5 |
+
batch_size: 8
|
| 6 |
+
num_workers: 4
|
| 7 |
+
prefetch_factor: 3
|
| 8 |
+
|
| 9 |
+
optimizer:
|
| 10 |
+
learning_rate: 1e-3
|
| 11 |
+
adam_eps: 1e-8
|
| 12 |
+
adam_weight_decay: 0.1
|
| 13 |
+
adam_betas: [0.9, 0.999]
|
| 14 |
+
|
| 15 |
+
warmup_steps: 10000
|
| 16 |
+
max_steps: 100000
|
| 17 |
+
|
| 18 |
+
validation_steps: 5000
|
| 19 |
+
log_interval: 10
|
| 20 |
+
|
| 21 |
+
enable_tf32: True
|
| 22 |
+
enable_amp: True
|
| 23 |
+
half_precision_format: "bfloat16" # or float16
|
| 24 |
+
enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
|
| 25 |
+
|
| 26 |
+
activation_checkpointing: True
|
| 27 |
+
|
| 28 |
+
datasets:
|
| 29 |
+
_target_: flava.definitions.TrainingDatasetsInfo
|
| 30 |
+
selected:
|
| 31 |
+
- image
|
| 32 |
+
- vl
|
| 33 |
+
- text
|
| 34 |
+
image:
|
| 35 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 36 |
+
train:
|
| 37 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 38 |
+
key: imagenet-1k
|
| 39 |
+
subset: default
|
| 40 |
+
text:
|
| 41 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 42 |
+
train:
|
| 43 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 44 |
+
key: wikitext
|
| 45 |
+
subset: wikitext-103-raw-v1
|
| 46 |
+
datamodule_extra_kwargs:
|
| 47 |
+
text_columns: ["text"]
|
| 48 |
+
vl:
|
| 49 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 50 |
+
train:
|
| 51 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 52 |
+
key: red_caps
|
| 53 |
+
subset: backpacking
|
| 54 |
+
rename_columns:
|
| 55 |
+
- ["caption", "text"]
|
| 56 |
+
val:
|
| 57 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 58 |
+
key: red_caps
|
| 59 |
+
subset: backpacking
|
| 60 |
+
rename_columns:
|
| 61 |
+
- ["caption", "text"]
|
| 62 |
+
split_key_mapping:
|
| 63 |
+
validation: train
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
model:
|
| 67 |
+
image_num_hidden_layers: 64
|
| 68 |
+
image_hidden_size: 2048
|
| 69 |
+
image_intermediate_size: 10240
|
| 70 |
+
image_num_attention_heads: 16
|
| 71 |
+
|
| 72 |
+
text_num_hidden_layers: 64
|
| 73 |
+
text_hidden_size: 2048
|
| 74 |
+
text_intermediate_size: 10240
|
| 75 |
+
text_num_attention_heads: 16
|
| 76 |
+
|
| 77 |
+
multimodal_num_hidden_layers: 40
|
| 78 |
+
multimodal_hidden_size: 2048
|
| 79 |
+
multimodal_intermediate_size: 10240
|
| 80 |
+
multimodal_num_attention_heads: 16
|
multimodal/examples/flava/native/configs/2.7b.yaml
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
training:
|
| 2 |
+
strategy: fsdp # can be changed to ddp or fsdp
|
| 3 |
+
seed: 1337
|
| 4 |
+
|
| 5 |
+
batch_size: 8
|
| 6 |
+
num_workers: 4
|
| 7 |
+
prefetch_factor: 3
|
| 8 |
+
|
| 9 |
+
optimizer:
|
| 10 |
+
learning_rate: 1e-3
|
| 11 |
+
adam_eps: 1e-8
|
| 12 |
+
adam_weight_decay: 0.1
|
| 13 |
+
adam_betas: [0.9, 0.999]
|
| 14 |
+
|
| 15 |
+
warmup_steps: 10000
|
| 16 |
+
max_steps: 100000
|
| 17 |
+
|
| 18 |
+
validation_steps: 5000
|
| 19 |
+
log_interval: 10
|
| 20 |
+
|
| 21 |
+
enable_tf32: True
|
| 22 |
+
enable_amp: True
|
| 23 |
+
half_precision_format: "bfloat16" # or float16
|
| 24 |
+
enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
|
| 25 |
+
|
| 26 |
+
activation_checkpointing: True
|
| 27 |
+
|
| 28 |
+
datasets:
|
| 29 |
+
_target_: flava.definitions.TrainingDatasetsInfo
|
| 30 |
+
selected:
|
| 31 |
+
- image
|
| 32 |
+
- vl
|
| 33 |
+
- text
|
| 34 |
+
image:
|
| 35 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 36 |
+
train:
|
| 37 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 38 |
+
key: imagenet-1k
|
| 39 |
+
subset: default
|
| 40 |
+
text:
|
| 41 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 42 |
+
train:
|
| 43 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 44 |
+
key: wikitext
|
| 45 |
+
subset: wikitext-103-raw-v1
|
| 46 |
+
datamodule_extra_kwargs:
|
| 47 |
+
text_columns: ["text"]
|
| 48 |
+
vl:
|
| 49 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 50 |
+
train:
|
| 51 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 52 |
+
key: red_caps
|
| 53 |
+
subset: backpacking
|
| 54 |
+
rename_columns:
|
| 55 |
+
- ["caption", "text"]
|
| 56 |
+
val:
|
| 57 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 58 |
+
key: red_caps
|
| 59 |
+
subset: backpacking
|
| 60 |
+
rename_columns:
|
| 61 |
+
- ["caption", "text"]
|
| 62 |
+
split_key_mapping:
|
| 63 |
+
validation: train
|
| 64 |
+
|
| 65 |
+
model:
|
| 66 |
+
image_num_hidden_layers: 40
|
| 67 |
+
image_hidden_size: 1408
|
| 68 |
+
image_intermediate_size: 6144
|
| 69 |
+
image_num_attention_heads: 16
|
| 70 |
+
|
| 71 |
+
text_num_hidden_layers: 40
|
| 72 |
+
text_hidden_size: 1408
|
| 73 |
+
text_intermediate_size: 6144
|
| 74 |
+
text_num_attention_heads: 16
|
| 75 |
+
|
| 76 |
+
multimodal_num_hidden_layers: 20
|
| 77 |
+
multimodal_hidden_size: 1408
|
| 78 |
+
multimodal_intermediate_size: 6144
|
| 79 |
+
multimodal_num_attention_heads: 16
|
multimodal/examples/flava/native/configs/4.8b.yaml
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
training:
|
| 2 |
+
strategy: fsdp # can be changed to ddp or fsdp
|
| 3 |
+
seed: 1337
|
| 4 |
+
|
| 5 |
+
batch_size: 12
|
| 6 |
+
num_workers: 4
|
| 7 |
+
prefetch_factor: 3
|
| 8 |
+
|
| 9 |
+
optimizer:
|
| 10 |
+
learning_rate: 1e-3
|
| 11 |
+
adam_eps: 1e-8
|
| 12 |
+
adam_weight_decay: 0.1
|
| 13 |
+
adam_betas: [0.9, 0.999]
|
| 14 |
+
|
| 15 |
+
warmup_steps: 10000
|
| 16 |
+
max_steps: 100000
|
| 17 |
+
|
| 18 |
+
validation_steps: 5000
|
| 19 |
+
log_interval: 10
|
| 20 |
+
|
| 21 |
+
enable_tf32: True
|
| 22 |
+
enable_amp: True
|
| 23 |
+
half_precision_format: "bfloat16" # or float16
|
| 24 |
+
enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
|
| 25 |
+
|
| 26 |
+
activation_checkpointing: True
|
| 27 |
+
|
| 28 |
+
datasets:
|
| 29 |
+
_target_: flava.definitions.TrainingDatasetsInfo
|
| 30 |
+
selected:
|
| 31 |
+
- image
|
| 32 |
+
- vl
|
| 33 |
+
- text
|
| 34 |
+
image:
|
| 35 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 36 |
+
train:
|
| 37 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 38 |
+
key: imagenet-1k
|
| 39 |
+
subset: default
|
| 40 |
+
text:
|
| 41 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 42 |
+
train:
|
| 43 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 44 |
+
key: wikitext
|
| 45 |
+
subset: wikitext-103-raw-v1
|
| 46 |
+
datamodule_extra_kwargs:
|
| 47 |
+
text_columns: ["text"]
|
| 48 |
+
vl:
|
| 49 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 50 |
+
train:
|
| 51 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 52 |
+
key: red_caps
|
| 53 |
+
subset: backpacking
|
| 54 |
+
rename_columns:
|
| 55 |
+
- ["caption", "text"]
|
| 56 |
+
val:
|
| 57 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 58 |
+
key: red_caps
|
| 59 |
+
subset: backpacking
|
| 60 |
+
rename_columns:
|
| 61 |
+
- ["caption", "text"]
|
| 62 |
+
split_key_mapping:
|
| 63 |
+
validation: train
|
| 64 |
+
|
| 65 |
+
model:
|
| 66 |
+
image_num_hidden_layers: 48
|
| 67 |
+
image_hidden_size: 1664
|
| 68 |
+
image_intermediate_size: 8192
|
| 69 |
+
image_num_attention_heads: 16
|
| 70 |
+
|
| 71 |
+
text_num_hidden_layers: 48
|
| 72 |
+
text_hidden_size: 1664
|
| 73 |
+
text_intermediate_size: 8192
|
| 74 |
+
text_num_attention_heads: 16
|
| 75 |
+
|
| 76 |
+
multimodal_num_hidden_layers: 24
|
| 77 |
+
multimodal_hidden_size: 1664
|
| 78 |
+
multimodal_intermediate_size: 8192
|
| 79 |
+
multimodal_num_attention_heads: 16
|
multimodal/examples/flava/native/configs/900m.yaml
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
training:
|
| 2 |
+
strategy: ddp # can be changed to ddp or fsdp
|
| 3 |
+
seed: 1337
|
| 4 |
+
|
| 5 |
+
batch_size: 8
|
| 6 |
+
num_workers: 4
|
| 7 |
+
prefetch_factor: 3
|
| 8 |
+
|
| 9 |
+
optimizer:
|
| 10 |
+
learning_rate: 1e-3
|
| 11 |
+
adam_eps: 1e-8
|
| 12 |
+
adam_weight_decay: 0.1
|
| 13 |
+
adam_betas: [0.9, 0.999]
|
| 14 |
+
|
| 15 |
+
warmup_steps: 10000
|
| 16 |
+
max_steps: 100000
|
| 17 |
+
|
| 18 |
+
validation_steps: 5000
|
| 19 |
+
log_interval: 10
|
| 20 |
+
|
| 21 |
+
enable_tf32: True
|
| 22 |
+
enable_amp: True
|
| 23 |
+
half_precision_format: "bfloat16" # or float16
|
| 24 |
+
enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
|
| 25 |
+
|
| 26 |
+
activation_checkpointing: True
|
| 27 |
+
|
| 28 |
+
datasets:
|
| 29 |
+
_target_: flava.definitions.TrainingDatasetsInfo
|
| 30 |
+
selected:
|
| 31 |
+
- image
|
| 32 |
+
- vl
|
| 33 |
+
- text
|
| 34 |
+
image:
|
| 35 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 36 |
+
train:
|
| 37 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 38 |
+
key: imagenet-1k
|
| 39 |
+
subset: default
|
| 40 |
+
text:
|
| 41 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 42 |
+
train:
|
| 43 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 44 |
+
key: wikitext
|
| 45 |
+
subset: wikitext-103-raw-v1
|
| 46 |
+
datamodule_extra_kwargs:
|
| 47 |
+
text_columns: ["text"]
|
| 48 |
+
vl:
|
| 49 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 50 |
+
train:
|
| 51 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 52 |
+
key: red_caps
|
| 53 |
+
subset: backpacking
|
| 54 |
+
rename_columns:
|
| 55 |
+
- ["caption", "text"]
|
| 56 |
+
val:
|
| 57 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 58 |
+
key: red_caps
|
| 59 |
+
subset: backpacking
|
| 60 |
+
rename_columns:
|
| 61 |
+
- ["caption", "text"]
|
| 62 |
+
split_key_mapping:
|
| 63 |
+
validation: train
|
| 64 |
+
|
| 65 |
+
model:
|
| 66 |
+
image_num_hidden_layers: 24
|
| 67 |
+
image_hidden_size: 1024
|
| 68 |
+
image_intermediate_size: 4096
|
| 69 |
+
image_num_attention_heads: 16
|
| 70 |
+
|
| 71 |
+
text_num_hidden_layers: 24
|
| 72 |
+
text_hidden_size: 1024
|
| 73 |
+
text_intermediate_size: 4096
|
| 74 |
+
text_num_attention_heads: 16
|
| 75 |
+
|
| 76 |
+
multimodal_num_hidden_layers: 12
|
| 77 |
+
multimodal_hidden_size: 1024
|
| 78 |
+
multimodal_intermediate_size: 4096
|
| 79 |
+
multimodal_num_attention_heads: 16
|
multimodal/examples/flava/native/configs/pretrain_debug.yaml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
training:
|
| 2 |
+
strategy: ddp # can be changed to ddp or fsdp
|
| 3 |
+
seed: 1337
|
| 4 |
+
|
| 5 |
+
batch_size: 8
|
| 6 |
+
num_workers: 4
|
| 7 |
+
prefetch_factor: 3
|
| 8 |
+
|
| 9 |
+
optimizer:
|
| 10 |
+
learning_rate: 1e-3
|
| 11 |
+
adam_eps: 1e-8
|
| 12 |
+
adam_weight_decay: 0.1
|
| 13 |
+
adam_betas: [0.9, 0.999]
|
| 14 |
+
|
| 15 |
+
warmup_steps: 10000
|
| 16 |
+
max_steps: 100000
|
| 17 |
+
|
| 18 |
+
validation_steps: 5000
|
| 19 |
+
log_interval: 10
|
| 20 |
+
|
| 21 |
+
enable_tf32: True
|
| 22 |
+
enable_amp: True
|
| 23 |
+
half_precision_format: "bfloat16" # or float16
|
| 24 |
+
enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
|
| 25 |
+
|
| 26 |
+
activation_checkpointing: False
|
| 27 |
+
|
| 28 |
+
datasets:
|
| 29 |
+
_target_: flava.definitions.TrainingDatasetsInfo
|
| 30 |
+
selected:
|
| 31 |
+
- image
|
| 32 |
+
- vl
|
| 33 |
+
- text
|
| 34 |
+
image:
|
| 35 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 36 |
+
train:
|
| 37 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 38 |
+
key: imagenet-1k
|
| 39 |
+
subset: default
|
| 40 |
+
text:
|
| 41 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 42 |
+
train:
|
| 43 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 44 |
+
key: wikitext
|
| 45 |
+
subset: wikitext-103-raw-v1
|
| 46 |
+
datamodule_extra_kwargs:
|
| 47 |
+
text_columns: ["text"]
|
| 48 |
+
vl:
|
| 49 |
+
_target_: flava.definitions.TrainingSingleDatasetInfo
|
| 50 |
+
train:
|
| 51 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 52 |
+
key: red_caps
|
| 53 |
+
subset: backpacking
|
| 54 |
+
rename_columns:
|
| 55 |
+
- ["caption", "text"]
|
| 56 |
+
val:
|
| 57 |
+
- _target_: flava.definitions.HFDatasetInfo
|
| 58 |
+
key: red_caps
|
| 59 |
+
subset: backpacking
|
| 60 |
+
rename_columns:
|
| 61 |
+
- ["caption", "text"]
|
| 62 |
+
split_key_mapping:
|
| 63 |
+
validation: train
|
multimodal/examples/flava/native/data.py
ADDED
|
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from functools import partial
|
| 9 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
import torchvision
|
| 14 |
+
|
| 15 |
+
from flava.data.transforms import (
|
| 16 |
+
default_image_pretraining_transforms,
|
| 17 |
+
default_text_transform,
|
| 18 |
+
default_torchvision_transforms,
|
| 19 |
+
encode_text_batch,
|
| 20 |
+
pad_batch,
|
| 21 |
+
TEXT_DEFAULT_TOKENIZER,
|
| 22 |
+
TEXT_WHOLE_WORD_MASK_TOKENIZER,
|
| 23 |
+
VL_MAX_LENGTH_DEFAULT,
|
| 24 |
+
VLTransform,
|
| 25 |
+
)
|
| 26 |
+
from flava.data.utils import build_datasets_from_info, fetch_images
|
| 27 |
+
from flava.definitions import HFDatasetInfo, TorchVisionDatasetInfo
|
| 28 |
+
from pytorch_lightning import LightningDataModule
|
| 29 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 30 |
+
from transformers import (
|
| 31 |
+
BertTokenizer,
|
| 32 |
+
DataCollatorForLanguageModeling,
|
| 33 |
+
DataCollatorForWholeWordMask,
|
| 34 |
+
DefaultDataCollator,
|
| 35 |
+
TRANSFORMERS_CACHE,
|
| 36 |
+
)
|
| 37 |
+
from transformers.data.data_collator import torch_default_data_collator
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def transform_image(transform, sample):
|
| 41 |
+
sample.update(transform(sample["image"]))
|
| 42 |
+
return sample
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_sampler(dataset, shuffle=True):
|
| 46 |
+
if dist.is_initialized():
|
| 47 |
+
return DistributedSampler(dataset, shuffle=shuffle)
|
| 48 |
+
if shuffle:
|
| 49 |
+
return torch.utils.data.RandomSampler(dataset)
|
| 50 |
+
return torch.utils.data.SequentialSampler(dataset)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class DataCollatorForWholeWordMaskRetainingBatch(DataCollatorForWholeWordMask):
|
| 54 |
+
def torch_call(
|
| 55 |
+
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
| 56 |
+
) -> Dict[str, Any]:
|
| 57 |
+
masked_batch = super().torch_call(examples)
|
| 58 |
+
examples = torch_default_data_collator(examples)
|
| 59 |
+
examples["input_ids"] = masked_batch["input_ids"]
|
| 60 |
+
examples["labels"] = masked_batch["labels"]
|
| 61 |
+
return examples
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ImageDataModule(LightningDataModule):
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
train_infos: List[HFDatasetInfo],
|
| 68 |
+
val_infos: Optional[List[HFDatasetInfo]] = None,
|
| 69 |
+
transforms: Optional[Tuple[Callable, Callable]] = None,
|
| 70 |
+
batch_size: int = 32,
|
| 71 |
+
num_workers: int = 4,
|
| 72 |
+
allow_uneven_batches: bool = False,
|
| 73 |
+
prefetch_factor: int = 2,
|
| 74 |
+
**kwargs: Any,
|
| 75 |
+
):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.train_dataset_infos = train_infos
|
| 78 |
+
self.val_dataset_infos = val_infos
|
| 79 |
+
if self.val_dataset_infos is None:
|
| 80 |
+
self.val_dataset_infos = train_infos
|
| 81 |
+
|
| 82 |
+
self.batch_size = batch_size
|
| 83 |
+
self.num_workers = num_workers
|
| 84 |
+
self.allow_uneven_batches = allow_uneven_batches
|
| 85 |
+
self.prefetch_factor = prefetch_factor
|
| 86 |
+
|
| 87 |
+
if transforms is None:
|
| 88 |
+
transforms = default_image_pretraining_transforms()
|
| 89 |
+
|
| 90 |
+
self.train_transform, self.test_transform = transforms
|
| 91 |
+
|
| 92 |
+
def setup(self, stage=None):
|
| 93 |
+
train_transform = partial(transform_image, self.train_transform)
|
| 94 |
+
val_transform = partial(transform_image, self.test_transform)
|
| 95 |
+
|
| 96 |
+
self.train_dataset = build_datasets_from_info(
|
| 97 |
+
self.train_dataset_infos, split="train"
|
| 98 |
+
)
|
| 99 |
+
self.train_dataset.set_transform(train_transform)
|
| 100 |
+
self.val_dataset = build_datasets_from_info(
|
| 101 |
+
self.val_dataset_infos, split="validation"
|
| 102 |
+
)
|
| 103 |
+
self.val_dataset.set_transform(val_transform)
|
| 104 |
+
|
| 105 |
+
def train_dataloader(self):
|
| 106 |
+
return torch.utils.data.DataLoader(
|
| 107 |
+
self.train_dataset,
|
| 108 |
+
batch_size=self.batch_size,
|
| 109 |
+
num_workers=self.num_workers,
|
| 110 |
+
sampler=get_sampler(self.train_dataset, shuffle=True),
|
| 111 |
+
pin_memory=True,
|
| 112 |
+
persistent_workers=True,
|
| 113 |
+
prefetch_factor=self.prefetch_factor,
|
| 114 |
+
# uneven batches can cause distributed issues,
|
| 115 |
+
# drop last batch to prevent those.
|
| 116 |
+
# ideally, we don't need to drop these for unimodal cases
|
| 117 |
+
# but just to be safe
|
| 118 |
+
drop_last=True,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def val_dataloader(self):
|
| 122 |
+
return torch.utils.data.DataLoader(
|
| 123 |
+
self.val_dataset,
|
| 124 |
+
batch_size=self.batch_size,
|
| 125 |
+
num_workers=self.num_workers,
|
| 126 |
+
sampler=get_sampler(self.val_dataset, shuffle=False),
|
| 127 |
+
pin_memory=True,
|
| 128 |
+
persistent_workers=True,
|
| 129 |
+
prefetch_factor=self.prefetch_factor,
|
| 130 |
+
# uneven batches can cause distributed issues,
|
| 131 |
+
# drop last batch to prevent those.
|
| 132 |
+
# ideally, we don't need to drop these for unimodal cases
|
| 133 |
+
# but just to be safe
|
| 134 |
+
drop_last=True,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
def test_dataloader(self):
|
| 138 |
+
return self.val_dataloader()
|
| 139 |
+
|
| 140 |
+
def on_before_batch_transfer(self, batch, *args):
|
| 141 |
+
if batch["label"].size(0) < self.batch_size and not self.allow_uneven_batches:
|
| 142 |
+
batch = pad_batch(batch, self.batch_size)
|
| 143 |
+
return batch
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class TextDataModule(LightningDataModule):
|
| 147 |
+
def __init__(
|
| 148 |
+
self,
|
| 149 |
+
train_infos: List[HFDatasetInfo],
|
| 150 |
+
text_columns: List[str],
|
| 151 |
+
val_infos: Optional[List[HFDatasetInfo]] = None,
|
| 152 |
+
tokenizer: Optional[Callable] = None,
|
| 153 |
+
max_length: int = 512,
|
| 154 |
+
batch_size: int = 32,
|
| 155 |
+
num_workers: int = 4,
|
| 156 |
+
allow_uneven_batches: bool = False,
|
| 157 |
+
prefetch_factor: int = 2,
|
| 158 |
+
**kwargs: Any,
|
| 159 |
+
):
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.train_dataset_infos = train_infos
|
| 162 |
+
self.text_columns = text_columns
|
| 163 |
+
self.val_dataset_infos = val_infos
|
| 164 |
+
if self.val_dataset_infos is None:
|
| 165 |
+
self.val_dataset_infos = train_infos
|
| 166 |
+
self.tokenizer = tokenizer
|
| 167 |
+
self.max_length = max_length
|
| 168 |
+
self.batch_size = batch_size
|
| 169 |
+
self.num_workers = num_workers
|
| 170 |
+
self.allow_uneven_batches = allow_uneven_batches
|
| 171 |
+
self.prefetch_factor = prefetch_factor
|
| 172 |
+
|
| 173 |
+
def setup(self, stage=None):
|
| 174 |
+
if self.tokenizer is None:
|
| 175 |
+
self.tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
|
| 176 |
+
transform = partial(
|
| 177 |
+
encode_text_batch,
|
| 178 |
+
tokenizer=self.tokenizer,
|
| 179 |
+
padding="max_length",
|
| 180 |
+
max_length=self.max_length,
|
| 181 |
+
truncation=True,
|
| 182 |
+
return_tensors="pt",
|
| 183 |
+
return_special_tokens_mask=True,
|
| 184 |
+
text_columns=self.text_columns,
|
| 185 |
+
return_batch=True,
|
| 186 |
+
)
|
| 187 |
+
self.train_dataset = build_datasets_from_info(
|
| 188 |
+
self.train_dataset_infos, split="train"
|
| 189 |
+
)
|
| 190 |
+
self.train_dataset.set_transform(transform)
|
| 191 |
+
self.val_dataset = build_datasets_from_info(
|
| 192 |
+
self.val_dataset_infos, split="validation"
|
| 193 |
+
)
|
| 194 |
+
self.val_dataset.set_transform(transform)
|
| 195 |
+
|
| 196 |
+
def train_dataloader(self):
|
| 197 |
+
return self._build_dataloader(self.train_dataset)
|
| 198 |
+
|
| 199 |
+
def val_dataloader(self):
|
| 200 |
+
return self._build_dataloader(self.val_dataset, shuffle=False)
|
| 201 |
+
|
| 202 |
+
def _build_dataloader(self, dataset, drop_last=False, shuffle=True):
|
| 203 |
+
return torch.utils.data.DataLoader(
|
| 204 |
+
dataset,
|
| 205 |
+
batch_size=self.batch_size,
|
| 206 |
+
num_workers=self.num_workers,
|
| 207 |
+
sampler=get_sampler(dataset, shuffle),
|
| 208 |
+
pin_memory=True,
|
| 209 |
+
persistent_workers=True,
|
| 210 |
+
prefetch_factor=self.prefetch_factor,
|
| 211 |
+
collate_fn=self._build_collator(),
|
| 212 |
+
drop_last=drop_last,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
def _build_collator(self):
|
| 216 |
+
return DefaultDataCollator()
|
| 217 |
+
|
| 218 |
+
def on_before_batch_transfer(self, batch, *args):
|
| 219 |
+
batch.pop("token_type_ids", None)
|
| 220 |
+
mask = batch.pop("attention_mask", None)
|
| 221 |
+
if mask.size(0) < self.batch_size and not self.allow_uneven_batches:
|
| 222 |
+
batch = pad_batch(batch, self.batch_size)
|
| 223 |
+
return batch
|
| 224 |
+
|
| 225 |
+
def on_after_batch_transfer(self, batch, *args):
|
| 226 |
+
batch["text"] = batch.pop("input_ids")
|
| 227 |
+
return batch
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class MLMDataModule(TextDataModule):
|
| 231 |
+
def __init__(
|
| 232 |
+
self,
|
| 233 |
+
train_infos: List[HFDatasetInfo],
|
| 234 |
+
text_columns: List[str],
|
| 235 |
+
val_infos: Optional[List[HFDatasetInfo]] = None,
|
| 236 |
+
mlm_probability: float = 0.15,
|
| 237 |
+
ignore_index: int = -1,
|
| 238 |
+
**kwargs: Any,
|
| 239 |
+
):
|
| 240 |
+
super().__init__(train_infos, text_columns, val_infos, **kwargs)
|
| 241 |
+
self.mlm_probability = mlm_probability
|
| 242 |
+
self.ignore_index = ignore_index
|
| 243 |
+
|
| 244 |
+
def setup(self, stage=None):
|
| 245 |
+
if self.tokenizer is None:
|
| 246 |
+
self.tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
|
| 247 |
+
transform = partial(
|
| 248 |
+
encode_text_batch,
|
| 249 |
+
tokenizer=self.tokenizer,
|
| 250 |
+
padding="max_length",
|
| 251 |
+
max_length=self.max_length,
|
| 252 |
+
truncation=True,
|
| 253 |
+
return_tensors="pt",
|
| 254 |
+
return_special_tokens_mask=True,
|
| 255 |
+
text_columns=self.text_columns,
|
| 256 |
+
return_batch=False,
|
| 257 |
+
)
|
| 258 |
+
self.train_dataset = build_datasets_from_info(
|
| 259 |
+
self.train_dataset_infos, split="train"
|
| 260 |
+
)
|
| 261 |
+
self.train_dataset.set_transform(transform)
|
| 262 |
+
self.val_dataset = build_datasets_from_info(
|
| 263 |
+
self.val_dataset_infos, split="validation"
|
| 264 |
+
)
|
| 265 |
+
self.val_dataset.set_transform(transform)
|
| 266 |
+
|
| 267 |
+
def _build_dataloader(self, dataset, drop_last=True, shuffle=True):
|
| 268 |
+
# uneven batches can cause distributed issues,
|
| 269 |
+
# drop last batch to prevent those.
|
| 270 |
+
# ideally, we don't need to drop these for unimodal cases
|
| 271 |
+
# but just to be safe
|
| 272 |
+
return super()._build_dataloader(dataset, drop_last=drop_last, shuffle=shuffle)
|
| 273 |
+
|
| 274 |
+
def _build_collator(self):
|
| 275 |
+
return DataCollatorForLanguageModeling(
|
| 276 |
+
self.tokenizer, mlm_probability=self.mlm_probability
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
def on_after_batch_transfer(self, batch, *args):
|
| 280 |
+
batch["text_masked"] = batch.pop("input_ids")
|
| 281 |
+
batch["mlm_labels"] = batch.pop("labels")
|
| 282 |
+
batch["mlm_labels"][batch["mlm_labels"] == -100] = self.ignore_index
|
| 283 |
+
return batch
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class VLDataModule(LightningDataModule):
|
| 287 |
+
def __init__(
|
| 288 |
+
self,
|
| 289 |
+
train_infos: List[HFDatasetInfo],
|
| 290 |
+
val_infos: List[HFDatasetInfo],
|
| 291 |
+
text_transform: Optional[Callable] = None,
|
| 292 |
+
image_transforms: Optional[Tuple[Callable, Callable]] = None,
|
| 293 |
+
mlm_probablity: float = 0.15,
|
| 294 |
+
batch_size: int = 32,
|
| 295 |
+
num_workers: int = 4,
|
| 296 |
+
finetuning: bool = False,
|
| 297 |
+
ignore_index: int = -1,
|
| 298 |
+
itm_probability: float = 0.1,
|
| 299 |
+
allow_uneven_batches: bool = False,
|
| 300 |
+
fetch_num_threads: int = 4,
|
| 301 |
+
fetch_retries: int = 0,
|
| 302 |
+
fetch_sleep_timer: int = 0,
|
| 303 |
+
fetch_timeout: Optional[float] = None,
|
| 304 |
+
fetch_batch_size: int = 50,
|
| 305 |
+
prefetch_factor=2,
|
| 306 |
+
**kwargs,
|
| 307 |
+
):
|
| 308 |
+
super().__init__()
|
| 309 |
+
|
| 310 |
+
self.train_dataset_infos = train_infos
|
| 311 |
+
self.val_dataset_infos = val_infos
|
| 312 |
+
if self.val_dataset_infos is None:
|
| 313 |
+
self.val_dataset_infos = train_infos
|
| 314 |
+
if image_transforms is None:
|
| 315 |
+
if not finetuning:
|
| 316 |
+
image_transforms = default_image_pretraining_transforms()
|
| 317 |
+
else:
|
| 318 |
+
image_transforms = default_torchvision_transforms(use_dict=True)
|
| 319 |
+
|
| 320 |
+
self.train_image_transform, self.test_image_transform = image_transforms
|
| 321 |
+
self.text_transform = text_transform
|
| 322 |
+
self.mlm_probability = mlm_probablity
|
| 323 |
+
self.batch_size = batch_size
|
| 324 |
+
self.num_workers = num_workers
|
| 325 |
+
self.ignore_index = ignore_index
|
| 326 |
+
self.itm_probability = itm_probability
|
| 327 |
+
self.allow_uneven_batches = allow_uneven_batches
|
| 328 |
+
self.fetch_num_threads = fetch_num_threads
|
| 329 |
+
self.fetch_retries = fetch_retries
|
| 330 |
+
self.fetch_sleep_timer = fetch_sleep_timer
|
| 331 |
+
self.fetch_timeout = fetch_timeout
|
| 332 |
+
self.fetch_batch_size = fetch_batch_size
|
| 333 |
+
self.prefetch_factor = prefetch_factor
|
| 334 |
+
|
| 335 |
+
def setup(self, stage=None):
|
| 336 |
+
if self.text_transform is None:
|
| 337 |
+
# TODO Update to use whole word mask vocab
|
| 338 |
+
text_tokenizer = BertTokenizer.from_pretrained(
|
| 339 |
+
TEXT_WHOLE_WORD_MASK_TOKENIZER
|
| 340 |
+
)
|
| 341 |
+
self.text_transform = default_text_transform(
|
| 342 |
+
text_tokenizer, max_text_length=VL_MAX_LENGTH_DEFAULT
|
| 343 |
+
)
|
| 344 |
+
self.text_tokenizer = self.text_transform.keywords["tokenizer"]
|
| 345 |
+
train_vl_transform = VLTransform(
|
| 346 |
+
self.train_image_transform, self.text_transform
|
| 347 |
+
)
|
| 348 |
+
val_vl_transform = VLTransform(self.test_image_transform, self.text_transform)
|
| 349 |
+
|
| 350 |
+
train_dataset = build_datasets_from_info(
|
| 351 |
+
self.train_dataset_infos, split="train"
|
| 352 |
+
)
|
| 353 |
+
train_dataset = train_dataset.map(
|
| 354 |
+
fetch_images,
|
| 355 |
+
batched=True,
|
| 356 |
+
batch_size=self.fetch_batch_size,
|
| 357 |
+
fn_kwargs={
|
| 358 |
+
"num_threads": self.fetch_num_threads,
|
| 359 |
+
"timeout": self.fetch_timeout,
|
| 360 |
+
"retries": self.fetch_retries,
|
| 361 |
+
"sleep_timer": self.fetch_sleep_timer,
|
| 362 |
+
},
|
| 363 |
+
)
|
| 364 |
+
train_dataset = train_dataset.filter(
|
| 365 |
+
lambda example: example["image"] is not None
|
| 366 |
+
)
|
| 367 |
+
self.train_dataset = train_dataset
|
| 368 |
+
self.train_dataset.set_transform(
|
| 369 |
+
partial(
|
| 370 |
+
train_vl_transform,
|
| 371 |
+
dataset=train_dataset.filter(lambda example: True),
|
| 372 |
+
itm_probability=self.itm_probability,
|
| 373 |
+
)
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
val_dataset = build_datasets_from_info(
|
| 377 |
+
self.val_dataset_infos, split="validation"
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
val_dataset = val_dataset.map(
|
| 381 |
+
fetch_images,
|
| 382 |
+
batched=True,
|
| 383 |
+
batch_size=self.fetch_batch_size,
|
| 384 |
+
fn_kwargs={
|
| 385 |
+
"num_threads": self.fetch_num_threads,
|
| 386 |
+
"timeout": self.fetch_timeout,
|
| 387 |
+
"retries": self.fetch_retries,
|
| 388 |
+
"sleep_timer": self.fetch_sleep_timer,
|
| 389 |
+
},
|
| 390 |
+
)
|
| 391 |
+
val_dataset = val_dataset.filter(lambda example: example["image"] is not None)
|
| 392 |
+
self.val_dataset = val_dataset
|
| 393 |
+
self.val_dataset.set_transform(
|
| 394 |
+
partial(
|
| 395 |
+
val_vl_transform,
|
| 396 |
+
dataset=self.val_dataset.filter(
|
| 397 |
+
lambda example: True
|
| 398 |
+
), # Pass a copy to transform
|
| 399 |
+
itm_probability=self.itm_probability,
|
| 400 |
+
)
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
def train_dataloader(self):
|
| 404 |
+
return torch.utils.data.DataLoader(
|
| 405 |
+
self.train_dataset,
|
| 406 |
+
batch_size=self.batch_size,
|
| 407 |
+
num_workers=self.num_workers,
|
| 408 |
+
sampler=get_sampler(self.train_dataset),
|
| 409 |
+
collate_fn=self._build_collator(),
|
| 410 |
+
pin_memory=True,
|
| 411 |
+
persistent_workers=True,
|
| 412 |
+
prefetch_factor=self.prefetch_factor,
|
| 413 |
+
# uneven batches can cause distributed issues,
|
| 414 |
+
# drop last batch to prevent those.
|
| 415 |
+
drop_last=True,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
def val_dataloader(self):
|
| 419 |
+
return torch.utils.data.DataLoader(
|
| 420 |
+
self.val_dataset,
|
| 421 |
+
batch_size=self.batch_size,
|
| 422 |
+
num_workers=self.num_workers,
|
| 423 |
+
sampler=get_sampler(self.val_dataset, shuffle=False),
|
| 424 |
+
collate_fn=self._build_collator(),
|
| 425 |
+
pin_memory=True,
|
| 426 |
+
persistent_workers=True,
|
| 427 |
+
prefetch_factor=self.prefetch_factor,
|
| 428 |
+
# uneven batches can cause distributed issues,
|
| 429 |
+
# drop last batch to prevent those.
|
| 430 |
+
drop_last=True,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
def _build_collator(self):
|
| 434 |
+
return DataCollatorForWholeWordMaskRetainingBatch(
|
| 435 |
+
self.text_tokenizer, mlm_probability=self.mlm_probability
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
def on_before_batch_transfer(self, batch, *args):
|
| 439 |
+
batch.pop("token_type_ids", None)
|
| 440 |
+
mask = batch.pop("attention_mask", None)
|
| 441 |
+
if (
|
| 442 |
+
mask is not None
|
| 443 |
+
and mask.size(0) < self.batch_size
|
| 444 |
+
and not self.allow_uneven_batches
|
| 445 |
+
):
|
| 446 |
+
batch = pad_batch(batch, self.batch_size)
|
| 447 |
+
return batch
|
| 448 |
+
|
| 449 |
+
def on_after_batch_transfer(self, batch, *args):
|
| 450 |
+
text_masked = batch.pop("input_ids")
|
| 451 |
+
mlm_labels = batch.pop("labels", None)
|
| 452 |
+
mlm_labels[mlm_labels == -100] = self.ignore_index
|
| 453 |
+
text = text_masked.detach().clone()
|
| 454 |
+
text[mlm_labels != -1] = mlm_labels[mlm_labels != -1]
|
| 455 |
+
batch.update(
|
| 456 |
+
{"mlm_labels": mlm_labels, "text": text, "text_masked": text_masked}
|
| 457 |
+
)
|
| 458 |
+
return batch
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class TorchVisionDataModule(LightningDataModule):
|
| 462 |
+
def __init__(
|
| 463 |
+
self,
|
| 464 |
+
train_infos: List[TorchVisionDatasetInfo],
|
| 465 |
+
# Val info is not used for torchvision datamodule, but kept to keep things consistent
|
| 466 |
+
val_infos: Optional[List[TorchVisionDatasetInfo]] = None,
|
| 467 |
+
dataset_root: Optional[str] = None,
|
| 468 |
+
image_transforms: Optional[Tuple[Callable, Callable]] = None,
|
| 469 |
+
batch_size: int = 32,
|
| 470 |
+
num_workers: int = 4,
|
| 471 |
+
prefetch_factor: int = 2,
|
| 472 |
+
**kwargs: Any,
|
| 473 |
+
):
|
| 474 |
+
super().__init__()
|
| 475 |
+
self.train_info = train_infos[0]
|
| 476 |
+
if val_infos is None:
|
| 477 |
+
val_infos = train_infos
|
| 478 |
+
self.val_info = val_infos[0]
|
| 479 |
+
|
| 480 |
+
self.train_class_ptr, self.train_root = self._parse_info(
|
| 481 |
+
self.train_info, dataset_root=dataset_root
|
| 482 |
+
)
|
| 483 |
+
self.val_class_ptr, self.val_root = self._parse_info(
|
| 484 |
+
self.val_info, dataset_root=dataset_root
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
if image_transforms is None:
|
| 488 |
+
image_transforms = default_torchvision_transforms()
|
| 489 |
+
|
| 490 |
+
self.train_transform, self.test_transform = image_transforms
|
| 491 |
+
self.batch_size = batch_size
|
| 492 |
+
self.num_workers = num_workers
|
| 493 |
+
self.prefetch_factor = prefetch_factor
|
| 494 |
+
|
| 495 |
+
def _parse_info(
|
| 496 |
+
self, info: TorchVisionDatasetInfo, dataset_root: Optional[str] = None
|
| 497 |
+
):
|
| 498 |
+
assert hasattr(
|
| 499 |
+
torchvision.datasets, info.key
|
| 500 |
+
), f"No dataset named {info.key} present in torchvision.datasets"
|
| 501 |
+
class_ptr = getattr(torchvision.datasets, info.key)
|
| 502 |
+
if dataset_root is None:
|
| 503 |
+
dataset_root = os.path.join(TRANSFORMERS_CACHE, "datasets", "torchvision")
|
| 504 |
+
dataset_root = os.path.join(dataset_root, class_ptr.__name__.lower())
|
| 505 |
+
os.makedirs(dataset_root, exist_ok=True)
|
| 506 |
+
|
| 507 |
+
return class_ptr, dataset_root
|
| 508 |
+
|
| 509 |
+
def setup(self, stage=None):
|
| 510 |
+
self.train_dataset = self.train_class_ptr(
|
| 511 |
+
self.train_root,
|
| 512 |
+
split=self.train_info.train_split,
|
| 513 |
+
transform=self.train_transform,
|
| 514 |
+
download=True,
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
if self.val_info.has_val:
|
| 518 |
+
self.val_dataset = self.val_class_ptr(
|
| 519 |
+
self.val_root,
|
| 520 |
+
split=self.val_info.val_split,
|
| 521 |
+
transform=self.test_transform,
|
| 522 |
+
download=True,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
self.test_dataset = self.val_class_ptr(
|
| 526 |
+
self.val_root,
|
| 527 |
+
split=self.val_info.test_split,
|
| 528 |
+
transform=self.test_transform,
|
| 529 |
+
download=True,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
def train_dataloader(self):
|
| 533 |
+
return self._build_dataloader(self.train_dataset)
|
| 534 |
+
|
| 535 |
+
def val_dataloader(self):
|
| 536 |
+
if self.val_info.has_val:
|
| 537 |
+
dataset = self.val_dataset
|
| 538 |
+
else:
|
| 539 |
+
dataset = self.test_dataset
|
| 540 |
+
|
| 541 |
+
return self._build_dataloader(dataset, shuffle=False)
|
| 542 |
+
|
| 543 |
+
def test_dataloader(self):
|
| 544 |
+
return self._build_dataloader(self.test_dataset, shuffle=False)
|
| 545 |
+
|
| 546 |
+
def _build_dataloader(self, dataset: torch.utils.data.Dataset, shuffle=True):
|
| 547 |
+
return torch.utils.data.DataLoader(
|
| 548 |
+
dataset,
|
| 549 |
+
sampler=get_sampler(dataset, shuffle),
|
| 550 |
+
batch_size=self.batch_size,
|
| 551 |
+
num_workers=self.num_workers,
|
| 552 |
+
pin_memory=True,
|
| 553 |
+
persistent_workers=True,
|
| 554 |
+
prefetch_factor=self.prefetch_factor,
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
def on_before_batch_transfer(self, batch, *args):
|
| 558 |
+
images, targets = batch
|
| 559 |
+
batch = {"image": images, "labels": targets}
|
| 560 |
+
return batch
|
multimodal/examples/flava/native/model.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Any, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torchmultimodal.models.flava.model import flava_model_for_pretraining
|
| 12 |
+
from transformers.optimization import get_cosine_schedule_with_warmup
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_optimizer(
|
| 16 |
+
model: torch.nn.Module,
|
| 17 |
+
learning_rate: float = 0.0002,
|
| 18 |
+
adam_eps: float = 1.0e-08,
|
| 19 |
+
adam_weight_decay: float = 0.01,
|
| 20 |
+
adam_betas: Tuple[int, int] = (0.9, 0.999),
|
| 21 |
+
warmup_steps: int = 2000,
|
| 22 |
+
max_steps: int = 450000,
|
| 23 |
+
):
|
| 24 |
+
optimizer = torch.optim.AdamW(
|
| 25 |
+
model.parameters(),
|
| 26 |
+
lr=learning_rate,
|
| 27 |
+
betas=adam_betas,
|
| 28 |
+
eps=adam_eps,
|
| 29 |
+
weight_decay=adam_weight_decay,
|
| 30 |
+
)
|
| 31 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 32 |
+
optimizer,
|
| 33 |
+
num_warmup_steps=warmup_steps,
|
| 34 |
+
num_training_steps=max_steps,
|
| 35 |
+
)
|
| 36 |
+
return optimizer, scheduler
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class FLAVAPreTrainModule(nn.Module):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
use_bf16: bool = True,
|
| 43 |
+
**flava_pretraining_kwargs: Any,
|
| 44 |
+
):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.model = flava_model_for_pretraining(**flava_pretraining_kwargs)
|
| 47 |
+
self.use_bf16 = use_bf16
|
| 48 |
+
|
| 49 |
+
def forward(self, batch, action=None):
|
| 50 |
+
# super hacky
|
| 51 |
+
if action == "encode_text":
|
| 52 |
+
return self.model.encode_text(batch)
|
| 53 |
+
elif action == "encode_image":
|
| 54 |
+
return self.model.encode_image(batch)
|
| 55 |
+
|
| 56 |
+
if "image" in batch and ("text" in batch or "text_masked" in batch):
|
| 57 |
+
required_embedding = "mm"
|
| 58 |
+
elif "image" in batch:
|
| 59 |
+
required_embedding = "image"
|
| 60 |
+
elif "text" in batch or "text_masked" in batch:
|
| 61 |
+
required_embedding = "text"
|
| 62 |
+
else:
|
| 63 |
+
raise RuntimeError("Batch needs to have either or both 'image' and 'text'.")
|
| 64 |
+
|
| 65 |
+
output = self.model(
|
| 66 |
+
image=batch.get("image"),
|
| 67 |
+
image_for_codebook=batch.get("image_for_codebook"),
|
| 68 |
+
image_patches_mask=batch.get("image_patches_mask"),
|
| 69 |
+
text=batch.get("text"),
|
| 70 |
+
text_masked=batch.get("text_masked"),
|
| 71 |
+
mlm_labels=batch.get("mlm_labels"),
|
| 72 |
+
itm_labels=batch.get("itm_labels"),
|
| 73 |
+
required_embedding=required_embedding,
|
| 74 |
+
)
|
| 75 |
+
return output
|
| 76 |
+
|
| 77 |
+
def encode_text(self, *args, **kwargs):
|
| 78 |
+
return self.model.encode_text(*args, **kwargs)
|
multimodal/examples/flava/native/train.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# example command to train:
|
| 8 |
+
# `torchrun --nproc_per_node=8 -m flava.native.train config=flava/native/configs/pretrain_debug.yaml`
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import time
|
| 12 |
+
from functools import partial
|
| 13 |
+
from typing import Any, Dict, Tuple, Union
|
| 14 |
+
|
| 15 |
+
import datasets
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
from common.data import MultiDataModule
|
| 20 |
+
from flava.definitions import FLAVAArguments
|
| 21 |
+
from flava.native.data import (
|
| 22 |
+
default_text_transform,
|
| 23 |
+
ImageDataModule,
|
| 24 |
+
MLMDataModule,
|
| 25 |
+
VL_MAX_LENGTH_DEFAULT,
|
| 26 |
+
VLDataModule,
|
| 27 |
+
)
|
| 28 |
+
from flava.native.model import FLAVAPreTrainModule, get_optimizer
|
| 29 |
+
from flava.native.utils import (
|
| 30 |
+
build_config,
|
| 31 |
+
enable_tf32,
|
| 32 |
+
get_model_parameters,
|
| 33 |
+
get_model_size_gb,
|
| 34 |
+
move_to_device,
|
| 35 |
+
print0,
|
| 36 |
+
run_imagenet_zero_shot,
|
| 37 |
+
set_seed,
|
| 38 |
+
setup_distributed_device,
|
| 39 |
+
)
|
| 40 |
+
from flava.utils import build_datamodule_kwargs
|
| 41 |
+
|
| 42 |
+
from omegaconf import DictConfig, OmegaConf
|
| 43 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
| 44 |
+
apply_activation_checkpointing,
|
| 45 |
+
checkpoint_wrapper,
|
| 46 |
+
CheckpointImpl,
|
| 47 |
+
)
|
| 48 |
+
from torch.distributed.elastic.multiprocessing.errors import record
|
| 49 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
|
| 50 |
+
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
| 51 |
+
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
| 52 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 53 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 54 |
+
from torchmultimodal.models.flava.image_encoder import ImageTransformer
|
| 55 |
+
from torchmultimodal.models.flava.text_encoder import BERTTextEncoder
|
| 56 |
+
from torchmultimodal.models.flava.transformer import (
|
| 57 |
+
FLAVATransformerWithoutEmbeddings,
|
| 58 |
+
TransformerEncoderLayer,
|
| 59 |
+
)
|
| 60 |
+
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLossOutput
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_datamodules(config: FLAVAArguments) -> Tuple[MultiDataModule, ImageDataModule]:
|
| 64 |
+
datamodules = []
|
| 65 |
+
|
| 66 |
+
# also needed for the imagenet eval callback
|
| 67 |
+
imagenet_datamodule = ImageDataModule(
|
| 68 |
+
**build_datamodule_kwargs(config.datasets.image, config.training)
|
| 69 |
+
)
|
| 70 |
+
for dataset in config.datasets.selected:
|
| 71 |
+
if dataset == "image":
|
| 72 |
+
datamodules.append(imagenet_datamodule)
|
| 73 |
+
elif dataset == "text":
|
| 74 |
+
datamodules.append(
|
| 75 |
+
MLMDataModule(
|
| 76 |
+
**build_datamodule_kwargs(config.datasets.text, config.training)
|
| 77 |
+
)
|
| 78 |
+
)
|
| 79 |
+
elif dataset == "vl":
|
| 80 |
+
datamodules.append(
|
| 81 |
+
VLDataModule(
|
| 82 |
+
**build_datamodule_kwargs(config.datasets.vl, config.training)
|
| 83 |
+
)
|
| 84 |
+
)
|
| 85 |
+
else:
|
| 86 |
+
raise ValueError(f"unknown dataset: {dataset}")
|
| 87 |
+
|
| 88 |
+
return MultiDataModule(datamodules), imagenet_datamodule
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@record
|
| 92 |
+
class Trainer:
|
| 93 |
+
def __init__(self, config: DictConfig):
|
| 94 |
+
if config.training.seed != -1:
|
| 95 |
+
set_seed(config.training.seed)
|
| 96 |
+
|
| 97 |
+
self.device: torch.device = setup_distributed_device()
|
| 98 |
+
self.config: DictConfig = config
|
| 99 |
+
self.rank: int = dist.get_rank()
|
| 100 |
+
self._logger: SummaryWriter = SummaryWriter(
|
| 101 |
+
f"logs/{config.training.strategy}/{int(time.time())}"
|
| 102 |
+
)
|
| 103 |
+
self.steps: int = -1
|
| 104 |
+
self.epochs: int = -1
|
| 105 |
+
|
| 106 |
+
multi_module, image_module = get_datamodules(config)
|
| 107 |
+
|
| 108 |
+
self.datamodule: MultiDataModule = multi_module
|
| 109 |
+
self.datamodule.setup("fit")
|
| 110 |
+
|
| 111 |
+
self.imagenet_val_dataloader = image_module.val_dataloader()
|
| 112 |
+
self.imagenet_val_text_transform = default_text_transform(
|
| 113 |
+
max_text_length=VL_MAX_LENGTH_DEFAULT
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
self.half_dtype = (
|
| 117 |
+
torch.bfloat16
|
| 118 |
+
if config.training.half_precision_format == "bfloat16"
|
| 119 |
+
else torch.float16
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
self.scaler = ShardedGradScaler() if config.training.enable_amp else None
|
| 123 |
+
|
| 124 |
+
def log(
|
| 125 |
+
self,
|
| 126 |
+
name: str,
|
| 127 |
+
value: Union[torch.Tensor, float, int],
|
| 128 |
+
log_rank_0: bool = True,
|
| 129 |
+
always_log: bool = False,
|
| 130 |
+
):
|
| 131 |
+
if log_rank_0 and self.rank != 0:
|
| 132 |
+
return
|
| 133 |
+
|
| 134 |
+
if always_log or self.steps % self.config.training.log_interval == 0:
|
| 135 |
+
self._logger.add_scalar(name, value, self.steps)
|
| 136 |
+
|
| 137 |
+
def create_model(self) -> torch.nn.Module:
|
| 138 |
+
model_config = self.config.get("model", {})
|
| 139 |
+
print0(f"using model config: {model_config}")
|
| 140 |
+
|
| 141 |
+
model = FLAVAPreTrainModule(**model_config)
|
| 142 |
+
strategy = self.config.training.strategy
|
| 143 |
+
|
| 144 |
+
print0(
|
| 145 |
+
f"before {strategy} model parameters: {get_model_parameters(model):,}, "
|
| 146 |
+
f"size: {get_model_size_gb(model):.3} GB"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
if self.config.training.activation_checkpointing:
|
| 150 |
+
check_fn = lambda submodule: isinstance(submodule, TransformerEncoderLayer)
|
| 151 |
+
checkpoint_impl = CheckpointImpl.REENTRANT
|
| 152 |
+
|
| 153 |
+
# DDP gradient hooks have compatibility issues with REENTRANT autograd
|
| 154 |
+
if strategy == "ddp":
|
| 155 |
+
checkpoint_impl = CheckpointImpl.NO_REENTRANT
|
| 156 |
+
|
| 157 |
+
checkpoint_wrapper_fn = partial(
|
| 158 |
+
checkpoint_wrapper,
|
| 159 |
+
offload_to_cpu=False,
|
| 160 |
+
checkpoint_impl=checkpoint_impl,
|
| 161 |
+
)
|
| 162 |
+
apply_activation_checkpointing(
|
| 163 |
+
model,
|
| 164 |
+
checkpoint_wrapper_fn=checkpoint_wrapper_fn,
|
| 165 |
+
check_fn=check_fn,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if strategy == "ddp":
|
| 169 |
+
# TODO do we have to do this in FSDP too? see https://github.com/pytorch/pytorch/issues/75478
|
| 170 |
+
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
| 171 |
+
model = model.to(self.device)
|
| 172 |
+
|
| 173 |
+
print0(
|
| 174 |
+
f"after moving to cuda: {torch.cuda.memory_allocated()/1024**3:.3} GB"
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
model = DDP(
|
| 178 |
+
model,
|
| 179 |
+
device_ids=[self.rank],
|
| 180 |
+
find_unused_parameters=True,
|
| 181 |
+
gradient_as_bucket_view=True,
|
| 182 |
+
)
|
| 183 |
+
print0(f"after DDP: {torch.cuda.memory_allocated()/1024**3:.3} GB")
|
| 184 |
+
elif strategy == "fsdp":
|
| 185 |
+
mp = None
|
| 186 |
+
if self.config.training.enable_half_reduce_in_fsdp:
|
| 187 |
+
mp = MixedPrecision(
|
| 188 |
+
# param_dtype=self.half_dtype, not working
|
| 189 |
+
reduce_dtype=self.half_dtype,
|
| 190 |
+
# buffer_dtype=self.half_dtype,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
model = FSDP(
|
| 194 |
+
model,
|
| 195 |
+
mixed_precision=mp,
|
| 196 |
+
device_id=self.device,
|
| 197 |
+
auto_wrap_policy=partial(
|
| 198 |
+
transformer_auto_wrap_policy,
|
| 199 |
+
transformer_layer_cls={
|
| 200 |
+
TransformerEncoderLayer,
|
| 201 |
+
ImageTransformer,
|
| 202 |
+
BERTTextEncoder,
|
| 203 |
+
FLAVATransformerWithoutEmbeddings,
|
| 204 |
+
},
|
| 205 |
+
),
|
| 206 |
+
limit_all_gathers=True,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
print0(f"after FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")
|
| 210 |
+
|
| 211 |
+
else:
|
| 212 |
+
raise ValueError(f"unknown strategy: {strategy}")
|
| 213 |
+
|
| 214 |
+
print0(
|
| 215 |
+
f"after {strategy} model parameters: {get_model_parameters(model):,}, "
|
| 216 |
+
f"size: {get_model_size_gb(model):.3} GB"
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
return model
|
| 220 |
+
|
| 221 |
+
def calculate_loss(
|
| 222 |
+
self, output: FLAVAPretrainingLossOutput, validation=False
|
| 223 |
+
) -> torch.Tensor:
|
| 224 |
+
losses = output.losses
|
| 225 |
+
|
| 226 |
+
total_loss = 0
|
| 227 |
+
for key in losses:
|
| 228 |
+
if losses[key] is not None:
|
| 229 |
+
total_loss += losses[key]
|
| 230 |
+
loss_reduce = losses[key].detach()
|
| 231 |
+
dist.reduce(loss_reduce, dst=0)
|
| 232 |
+
if validation:
|
| 233 |
+
mode = "validation"
|
| 234 |
+
else:
|
| 235 |
+
mode = "train"
|
| 236 |
+
self.log(
|
| 237 |
+
f"{mode}/losses/{key}",
|
| 238 |
+
loss_reduce.item() / dist.get_world_size(),
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
return total_loss
|
| 242 |
+
|
| 243 |
+
def preprocess_data(self, data: Dict[str, Any]):
|
| 244 |
+
data = self.datamodule.on_before_batch_transfer(data, None)
|
| 245 |
+
data = move_to_device(data, self.device)
|
| 246 |
+
return self.datamodule.on_after_batch_transfer(data, None)
|
| 247 |
+
|
| 248 |
+
def _log_iteration_times(self, iteration_times):
|
| 249 |
+
profile_warmup_steps = config.get("profile_warmup_steps", 100)
|
| 250 |
+
start_idx = (
|
| 251 |
+
profile_warmup_steps
|
| 252 |
+
if profile_warmup_steps < self.config.training.max_steps
|
| 253 |
+
else 0
|
| 254 |
+
)
|
| 255 |
+
iteration_times = iteration_times[start_idx:]
|
| 256 |
+
avg_it_time = np.mean(iteration_times)
|
| 257 |
+
avg_throughput = (
|
| 258 |
+
config.training.batch_size * dist.get_world_size()
|
| 259 |
+
) / avg_it_time
|
| 260 |
+
print0(f"Average over {len(iteration_times)} steps")
|
| 261 |
+
print0(f"Average iteration time {round(avg_it_time,4)}")
|
| 262 |
+
print0(f"Average throughput {round(avg_throughput,4)}")
|
| 263 |
+
|
| 264 |
+
def train(self) -> None:
|
| 265 |
+
print0(OmegaConf.to_container(self.config.training))
|
| 266 |
+
self.model = self.create_model()
|
| 267 |
+
model = self.model
|
| 268 |
+
|
| 269 |
+
optimizer, scheduler = get_optimizer(
|
| 270 |
+
model,
|
| 271 |
+
**self.config.training.optimizer,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
iteration_times = []
|
| 275 |
+
|
| 276 |
+
while True:
|
| 277 |
+
t0 = time.time()
|
| 278 |
+
self.epochs += 1
|
| 279 |
+
dataloader = self.datamodule.train_dataloader()
|
| 280 |
+
dataloader.set_epoch(self.epochs)
|
| 281 |
+
|
| 282 |
+
for i, data in enumerate(dataloader):
|
| 283 |
+
torch.cuda.reset_peak_memory_stats()
|
| 284 |
+
|
| 285 |
+
self.steps += 1
|
| 286 |
+
|
| 287 |
+
if self.config.training.max_steps < self.steps:
|
| 288 |
+
if self.rank == 0:
|
| 289 |
+
self._log_iteration_times(iteration_times)
|
| 290 |
+
print0("Max steps reached, exiting")
|
| 291 |
+
return
|
| 292 |
+
|
| 293 |
+
model.train()
|
| 294 |
+
data = self.preprocess_data(data)
|
| 295 |
+
optimizer.zero_grad(set_to_none=True)
|
| 296 |
+
|
| 297 |
+
with torch.cuda.amp.autocast(
|
| 298 |
+
dtype=self.half_dtype, enabled=bool(self.scaler)
|
| 299 |
+
):
|
| 300 |
+
output = model(data)
|
| 301 |
+
print0(
|
| 302 |
+
f"after forward pass {torch.cuda.memory_allocated()/1024**3:.3} GB"
|
| 303 |
+
)
|
| 304 |
+
self.log(
|
| 305 |
+
"stats/fwd memory alloc",
|
| 306 |
+
torch.cuda.memory_allocated() / 1024**3,
|
| 307 |
+
)
|
| 308 |
+
self.log(
|
| 309 |
+
"stats/fwd memory reserved",
|
| 310 |
+
torch.cuda.memory_reserved() / 1024**3,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
total_loss = self.calculate_loss(output)
|
| 314 |
+
|
| 315 |
+
if self.scaler:
|
| 316 |
+
self.scaler.scale(total_loss).backward()
|
| 317 |
+
self.scaler.step(optimizer)
|
| 318 |
+
self.scaler.update()
|
| 319 |
+
else:
|
| 320 |
+
total_loss.backward()
|
| 321 |
+
optimizer.step()
|
| 322 |
+
|
| 323 |
+
scheduler.step()
|
| 324 |
+
torch.cuda.synchronize()
|
| 325 |
+
t1 = time.time()
|
| 326 |
+
batch_time = t1 - t0
|
| 327 |
+
batch_size = config.training.batch_size * dist.get_world_size()
|
| 328 |
+
items_time = batch_size / (t1 - t0)
|
| 329 |
+
|
| 330 |
+
t0 = t1
|
| 331 |
+
self.log("stats/sec per batch", batch_time)
|
| 332 |
+
self.log("stats/items per sec", items_time)
|
| 333 |
+
|
| 334 |
+
total_loss = total_loss.detach()
|
| 335 |
+
dist.reduce(total_loss, dst=0)
|
| 336 |
+
|
| 337 |
+
if self.rank == 0:
|
| 338 |
+
norm_total_loss = total_loss.item() / dist.get_world_size()
|
| 339 |
+
|
| 340 |
+
print(
|
| 341 |
+
f"epoch: {self.epochs} step {self.steps} loss: {norm_total_loss:.4}"
|
| 342 |
+
)
|
| 343 |
+
self.log("train/loss", norm_total_loss)
|
| 344 |
+
self.log("stats/batch_size", batch_size)
|
| 345 |
+
|
| 346 |
+
iteration_times.append(batch_time)
|
| 347 |
+
|
| 348 |
+
cuda_info = torch.cuda.memory_stats()
|
| 349 |
+
print("cuda alloc retries ", cuda_info.get("num_alloc_retries", 0))
|
| 350 |
+
|
| 351 |
+
self.log(
|
| 352 |
+
"stats/max_gpu_allocated_gb",
|
| 353 |
+
torch.cuda.max_memory_allocated() / 1024**3,
|
| 354 |
+
)
|
| 355 |
+
# TODO implement imagenet eval
|
| 356 |
+
# TODO implement checkpoint saving
|
| 357 |
+
|
| 358 |
+
self.validate()
|
| 359 |
+
|
| 360 |
+
def validate(self):
|
| 361 |
+
if self.steps % self.config.training.validation_steps != 0 or self.steps == 0:
|
| 362 |
+
return
|
| 363 |
+
|
| 364 |
+
model = self.model
|
| 365 |
+
model.eval()
|
| 366 |
+
print0("evaluating")
|
| 367 |
+
|
| 368 |
+
validation_loader = self.datamodule.val_dataloader()
|
| 369 |
+
validation_loss = torch.Tensor([0]).to(self.device)
|
| 370 |
+
|
| 371 |
+
for data in validation_loader:
|
| 372 |
+
data = self.preprocess_data(data)
|
| 373 |
+
with torch.no_grad():
|
| 374 |
+
with torch.cuda.amp.autocast(
|
| 375 |
+
dtype=self.half_dtype, enabled=bool(self.scaler)
|
| 376 |
+
):
|
| 377 |
+
output = model(data)
|
| 378 |
+
total_loss = self.calculate_loss(output, validation=True)
|
| 379 |
+
validation_loss += total_loss.detach()
|
| 380 |
+
|
| 381 |
+
dist.reduce(validation_loss, dst=0)
|
| 382 |
+
norm_validation_loss = validation_loss.item() / dist.get_world_size()
|
| 383 |
+
|
| 384 |
+
print0(f"step {self.steps} EVAL loss: {norm_validation_loss:.4}")
|
| 385 |
+
|
| 386 |
+
def imagenet_validate(self):
|
| 387 |
+
print0("imagenet validation")
|
| 388 |
+
with torch.no_grad():
|
| 389 |
+
with torch.cuda.amp.autocast(
|
| 390 |
+
dtype=self.half_dtype, enabled=bool(self.scaler)
|
| 391 |
+
):
|
| 392 |
+
metrics = run_imagenet_zero_shot(
|
| 393 |
+
self.model,
|
| 394 |
+
self.imagenet_val_dataloader,
|
| 395 |
+
self.device,
|
| 396 |
+
self.imagenet_val_text_transform,
|
| 397 |
+
)
|
| 398 |
+
if metrics is not None:
|
| 399 |
+
for key in metrics:
|
| 400 |
+
self.log(
|
| 401 |
+
f"val/imagenet/{key}",
|
| 402 |
+
metrics[key],
|
| 403 |
+
always_log=True,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
if __name__ == "__main__":
|
| 408 |
+
datasets.logging.set_verbosity_error() # too spammy
|
| 409 |
+
|
| 410 |
+
config: FLAVAArguments = build_config()
|
| 411 |
+
if config.training.enable_tf32:
|
| 412 |
+
enable_tf32()
|
| 413 |
+
|
| 414 |
+
trainer = Trainer(config)
|
| 415 |
+
trainer.train()
|
multimodal/examples/flava/native/utils.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from flava.data.imagenet_zeroshot_data import (
|
| 13 |
+
imagenet_classnames,
|
| 14 |
+
openai_imagenet_template,
|
| 15 |
+
)
|
| 16 |
+
from hydra.utils import instantiate
|
| 17 |
+
from omegaconf import DictConfig, OmegaConf
|
| 18 |
+
from torch import distributed as dist
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
# optional syntax-highlighting for console output
|
| 22 |
+
try:
|
| 23 |
+
from rich.console import Console
|
| 24 |
+
|
| 25 |
+
c = Console(force_terminal=True)
|
| 26 |
+
print = c.log
|
| 27 |
+
except ImportError:
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def build_config() -> DictConfig:
|
| 32 |
+
cli_conf = OmegaConf.from_cli()
|
| 33 |
+
yaml_conf = OmegaConf.load(cli_conf.config)
|
| 34 |
+
conf = instantiate(yaml_conf)
|
| 35 |
+
conf = OmegaConf.merge(conf, cli_conf)
|
| 36 |
+
return conf
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# TODO replace with tlc.copy_data_to_device
|
| 40 |
+
def move_to_device(obj: Any, device: torch.device) -> Any:
|
| 41 |
+
if isinstance(obj, dict):
|
| 42 |
+
d = {}
|
| 43 |
+
for k, v in obj.items():
|
| 44 |
+
d[k] = move_to_device(v, device)
|
| 45 |
+
return d
|
| 46 |
+
if isinstance(obj, list):
|
| 47 |
+
l = []
|
| 48 |
+
for v in obj:
|
| 49 |
+
l.append(move_to_device(v, device))
|
| 50 |
+
return l
|
| 51 |
+
|
| 52 |
+
return obj.to(device)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_model_size_gb(model: torch.nn.Module) -> int:
|
| 56 |
+
return sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_model_parameters(model: torch.nn.Module) -> int:
|
| 60 |
+
return sum(p.numel() for p in model.parameters())
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def set_seed(seed: int) -> None:
|
| 64 |
+
torch.manual_seed(seed)
|
| 65 |
+
random.seed(seed)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def setup_distributed_device() -> torch.device:
|
| 69 |
+
if not torch.cuda.is_available() or not dist.is_available():
|
| 70 |
+
return torch.device("cpu")
|
| 71 |
+
|
| 72 |
+
dist.init_process_group("nccl")
|
| 73 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 74 |
+
print("local rank", local_rank)
|
| 75 |
+
torch.cuda.set_device(local_rank)
|
| 76 |
+
return torch.device(f"cuda:{local_rank}")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def print0(*args, **kwargs) -> None:
|
| 80 |
+
if not dist.is_initialized() or dist.get_rank() == 0:
|
| 81 |
+
print(*args, **kwargs)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def enable_tf32() -> None:
|
| 85 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 86 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def rank0_only(func):
|
| 90 |
+
def wrapper(*args, **kwargs):
|
| 91 |
+
if not dist.is_initialized() or dist.get_rank() == 0:
|
| 92 |
+
return func(*args, **kwargs)
|
| 93 |
+
|
| 94 |
+
return wrapper
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# zero shot classifier functions
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _zero_shot_classifier(model, device, text_transform, *args, **kwargs):
|
| 101 |
+
zeroshot_weights = []
|
| 102 |
+
for classname in tqdm(imagenet_classnames):
|
| 103 |
+
texts = text_transform(
|
| 104 |
+
[template(classname) for template in openai_imagenet_template]
|
| 105 |
+
)["input_ids"]
|
| 106 |
+
texts = texts.to(device)
|
| 107 |
+
class_embeddings = model(texts, action="encode_text")
|
| 108 |
+
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
|
| 109 |
+
class_embedding = class_embeddings.mean(dim=0)
|
| 110 |
+
class_embedding /= class_embedding.norm()
|
| 111 |
+
zeroshot_weights.append(class_embedding)
|
| 112 |
+
|
| 113 |
+
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
|
| 114 |
+
return zeroshot_weights
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _accuracy(output, target, topk=(1,)):
|
| 118 |
+
pred = output.topk(max(topk), 1, True, True)[1].t()
|
| 119 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
| 120 |
+
return [
|
| 121 |
+
float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
|
| 122 |
+
for k in topk
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def run_imagenet_zero_shot(model, dataloader, device, text_transform, *args, **kwargs):
|
| 127 |
+
print0("Starting ImageNet Zero-Shot Eval")
|
| 128 |
+
print0("Building classifier")
|
| 129 |
+
classifier = _zero_shot_classifier(model, device, text_transform)
|
| 130 |
+
print0("Classifier built")
|
| 131 |
+
top1, top5, n = 0.0, 0.0, 0.0
|
| 132 |
+
for i, sample in tqdm(enumerate(dataloader)):
|
| 133 |
+
images = sample["image"]
|
| 134 |
+
target = sample["label"]
|
| 135 |
+
images = images.to(device)
|
| 136 |
+
target = target.to(device)
|
| 137 |
+
|
| 138 |
+
# predict
|
| 139 |
+
# if hasattr(model, "module"):
|
| 140 |
+
# image_features = model.module.encode_image({"image": images})
|
| 141 |
+
# else:
|
| 142 |
+
image_features = model(images, action="encode_image")
|
| 143 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
| 144 |
+
logits = 100.0 * image_features @ classifier
|
| 145 |
+
|
| 146 |
+
# measure accuracy
|
| 147 |
+
acc1, acc5 = _accuracy(logits, target, topk=(1, 5))
|
| 148 |
+
top1 += acc1
|
| 149 |
+
top5 += acc5
|
| 150 |
+
n += images.size(0)
|
| 151 |
+
if i == 5:
|
| 152 |
+
break
|
| 153 |
+
|
| 154 |
+
top1 = top1 / n
|
| 155 |
+
top5 = top5 / n
|
| 156 |
+
results = {}
|
| 157 |
+
results["imagenet-zeroshot-val-top1"] = top1
|
| 158 |
+
results["imagenet-zeroshot-val-top5"] = top5
|
| 159 |
+
print0("results: ", results)
|
| 160 |
+
return results
|
multimodal/examples/flava/notebooks/RemapFLAVACheckpoint.ipynb
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "7cc982d1",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Re-map FLAVA checkpoint\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"Modifying FLAVA's components can cause existing model checkpoints to go out of sync with the updated architecture. This notebook shows how to load the existing checkpoint, re-map the old layers to the new layers, and save the new checkpoint.\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"To upload a new checkpoint, you must have access to the PyTorch AWS S3 account, and manually upload it from a local copy."
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "markdown",
|
| 17 |
+
"id": "411e4191",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"source": [
|
| 20 |
+
"### Load original model\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"Load the existing checkpoint into the FLAVA class to see what the architecture currently is."
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "code",
|
| 27 |
+
"execution_count": 3,
|
| 28 |
+
"id": "88ee917b",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"outputs": [],
|
| 31 |
+
"source": [
|
| 32 |
+
"import torch\n",
|
| 33 |
+
"from torchmultimodal.models.flava.model import flava_model_for_classification, flava_model_for_pretraining\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"# flava_classification = flava_model_for_classification(num_classes=3)\n",
|
| 36 |
+
"flava_pretraining = flava_model_for_pretraining(pretrained_model_key='flava_full')"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "markdown",
|
| 41 |
+
"id": "5f00b369",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"source": [
|
| 44 |
+
"### Print summary"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": null,
|
| 50 |
+
"id": "cc286394",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"flava_pretraining"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "markdown",
|
| 59 |
+
"id": "0d774455",
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"source": [
|
| 62 |
+
"### Mapping function\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"Replace this function with the code needed to map the old layer weights to the new layer weights."
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "code",
|
| 69 |
+
"execution_count": 4,
|
| 70 |
+
"id": "cc9e4537",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"outputs": [],
|
| 73 |
+
"source": [
|
| 74 |
+
"import re\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"def map_state_dict(state_dict):\n",
|
| 77 |
+
" mapped_state_dict = {}\n",
|
| 78 |
+
" for param, val in state_dict.items():\n",
|
| 79 |
+
" res = re.search('attention.attention', param)\n",
|
| 80 |
+
" if res:\n",
|
| 81 |
+
" idx = res.start()\n",
|
| 82 |
+
" new_param = param[:idx] + param[idx+10:]\n",
|
| 83 |
+
" else:\n",
|
| 84 |
+
" new_param = param\n",
|
| 85 |
+
" mapped_state_dict[new_param] = val\n",
|
| 86 |
+
" return mapped_state_dict"
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"cell_type": "markdown",
|
| 91 |
+
"id": "29870590",
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"source": [
|
| 94 |
+
"### Load old state dict"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "code",
|
| 99 |
+
"execution_count": 5,
|
| 100 |
+
"id": "41f64d26",
|
| 101 |
+
"metadata": {},
|
| 102 |
+
"outputs": [],
|
| 103 |
+
"source": [
|
| 104 |
+
"# Load from url, replace this path if it changes\n",
|
| 105 |
+
"# old_model_url = 'https://download.pytorch.org/models/multimodal/flava/flava_model.pt'\n",
|
| 106 |
+
"# old_state_dict = torch.hub.load_state_dict_from_url(old_model_url)\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"# Or get from loaded model\n",
|
| 109 |
+
"old_state_dict = flava_pretraining.model.state_dict()"
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"cell_type": "markdown",
|
| 114 |
+
"id": "75322113",
|
| 115 |
+
"metadata": {},
|
| 116 |
+
"source": [
|
| 117 |
+
"### Perform re-mapping"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "code",
|
| 122 |
+
"execution_count": 6,
|
| 123 |
+
"id": "17363ae8",
|
| 124 |
+
"metadata": {},
|
| 125 |
+
"outputs": [],
|
| 126 |
+
"source": [
|
| 127 |
+
"#new_state_dict = map_state_dict(old_state_dict)\n",
|
| 128 |
+
"new_state_dict = old_state_dict"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "markdown",
|
| 133 |
+
"id": "d94c4133",
|
| 134 |
+
"metadata": {},
|
| 135 |
+
"source": [
|
| 136 |
+
"### Save updated checkpoint"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "code",
|
| 141 |
+
"execution_count": 7,
|
| 142 |
+
"id": "bc6baad9",
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"outputs": [],
|
| 145 |
+
"source": [
|
| 146 |
+
"save_path = '/Users/rafiayub/flava_model.pt'\n",
|
| 147 |
+
"torch.save(new_state_dict, save_path)"
|
| 148 |
+
]
|
| 149 |
+
}
|
| 150 |
+
],
|
| 151 |
+
"metadata": {
|
| 152 |
+
"kernelspec": {
|
| 153 |
+
"display_name": "Python 3 (ipykernel)",
|
| 154 |
+
"language": "python",
|
| 155 |
+
"name": "python3"
|
| 156 |
+
},
|
| 157 |
+
"language_info": {
|
| 158 |
+
"codemirror_mode": {
|
| 159 |
+
"name": "ipython",
|
| 160 |
+
"version": 3
|
| 161 |
+
},
|
| 162 |
+
"file_extension": ".py",
|
| 163 |
+
"mimetype": "text/x-python",
|
| 164 |
+
"name": "python",
|
| 165 |
+
"nbconvert_exporter": "python",
|
| 166 |
+
"pygments_lexer": "ipython3",
|
| 167 |
+
"version": "3.9.12"
|
| 168 |
+
}
|
| 169 |
+
},
|
| 170 |
+
"nbformat": 4,
|
| 171 |
+
"nbformat_minor": 5
|
| 172 |
+
}
|
multimodal/examples/flava/tools/convert_weights.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torchmultimodal.models.flava.model import flava_model_for_pretraining
|
| 11 |
+
|
| 12 |
+
KEY_REPLACEMENTS = {
|
| 13 |
+
"image_encoder.module": "image_encoder",
|
| 14 |
+
"text_encoder.module": "text_encoder",
|
| 15 |
+
"mm_encoder.module": "mm_encoder",
|
| 16 |
+
"mm_encoder.encoder.cls_token": "mm_encoder.cls_token",
|
| 17 |
+
"mm_image_projection": "image_to_mm_projection",
|
| 18 |
+
"mm_text_projection": "text_to_mm_projection",
|
| 19 |
+
"model.heads.cmd.mim_head": "loss.mmm_loss.mim",
|
| 20 |
+
"model.heads.cmd.mlm_head": "loss.mmm_loss.mlm",
|
| 21 |
+
"model.heads.fairseq_mlm": "loss.mlm_loss",
|
| 22 |
+
"model.heads.imagenet.mim_head": "loss.mim_loss",
|
| 23 |
+
"cls.predictions.transform": "cls",
|
| 24 |
+
"cls.predictions": "cls",
|
| 25 |
+
"cls.LayerNorm": "cls.layer_norm",
|
| 26 |
+
"model.text_projection": "loss.contrastive_loss.text_projection",
|
| 27 |
+
"model.image_projection": "loss.contrastive_loss.image_projection",
|
| 28 |
+
"model.heads.cmd.clip_head.logit_scale": "loss.contrastive_loss.logit_scale",
|
| 29 |
+
"model.heads.cmd.itm_head": "loss.itm_loss",
|
| 30 |
+
"intermediate.dense": "intermediate",
|
| 31 |
+
"output.dense": "output",
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def convert_weights(args):
|
| 36 |
+
ckpt = torch.load(args.ckpt_file, map_location="cpu")
|
| 37 |
+
flava = flava_model_for_pretraining()
|
| 38 |
+
model = ckpt["model"]
|
| 39 |
+
import pdb
|
| 40 |
+
|
| 41 |
+
pdb.set_trace()
|
| 42 |
+
for key in list(model.keys()):
|
| 43 |
+
original = key
|
| 44 |
+
for option, replacement in KEY_REPLACEMENTS.items():
|
| 45 |
+
key = key.replace(option, replacement)
|
| 46 |
+
model[key] = model.pop(original)
|
| 47 |
+
|
| 48 |
+
if args.add_codebook:
|
| 49 |
+
# Since codebook is anyways not trained in FLAVA pretraining
|
| 50 |
+
# we can use the pretrained one that we get from FLAVA initialized
|
| 51 |
+
# model
|
| 52 |
+
model.update(
|
| 53 |
+
{
|
| 54 |
+
f"image_codebook.{key}": value
|
| 55 |
+
for key, value in flava.image_codebook.state_dict().items()
|
| 56 |
+
}
|
| 57 |
+
)
|
| 58 |
+
flava.load_state_dict(model)
|
| 59 |
+
|
| 60 |
+
# Let's save the model now.
|
| 61 |
+
torch.save(flava.state_dict(), args.save_file)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
parser = argparse.ArgumentParser(description="Convert weights")
|
| 66 |
+
parser.add_argument("ckpt_file", type=str)
|
| 67 |
+
parser.add_argument("save_file", type=str)
|
| 68 |
+
parser.add_argument("--add_codebook", action="store_true")
|
| 69 |
+
|
| 70 |
+
args = parser.parse_args()
|
| 71 |
+
|
| 72 |
+
convert_weights(args)
|
multimodal/examples/mugen/data/README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
This folder contains code for interfacing the [MUGEN dataset](https://mugen-org.github.io). The MUGEN dataset contains over 300k videos, each with corresponding audio and text, from the game CoinRun.
|
| 2 |
+
|
| 3 |
+
Before using this code,
|
| 4 |
+
|
| 5 |
+
1. Download the 3.2s-video dataset [here](https://mugen-org.github.io/download) and save as `datasets/coinrun` in your working directory.
|
| 6 |
+
* In each of `datasets/coinrun/coinrun_dataset_jsons/release/{train/val/test}.json`, change the value of `json_object["metadata"]["data_folder"]` to the absolute path of `datasets/coinrun`, e.g. `"/path/to/datasets/coinrun/"`.
|
| 7 |
+
2. Download the MUGEN dataset assets [here](https://github.com/mugen-org/MUGEN_baseline/tree/main/lib/data/coinrun/assets) and save under `datasets/coinrun` as `datasets/coinrun/assets` in your pwd.
|
| 8 |
+
* Downloading the assets from GitHub requires `git clone`-ing the original MUGEN repo and copying the assets directory located at `MUGEN_baseline/lib/data/coinrun/assets`.
|
| 9 |
+
|
| 10 |
+
Note: saving the dataset and assets to locations other than those listed above requires passing custom arguments to `MUGENDataModuleBase` or `MUGENDataset` through `MUGENDatasetArgs.data_path` and `MUGENDatasetArgs.asset_path`, respectively.
|
multimodal/examples/mugen/data/coinrun/construct_from_json.py
ADDED
|
@@ -0,0 +1,756 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
DEATH_ANIM_LENGTH = 30
|
| 14 |
+
FINISHED_LEVEL_ANIM_LENGTH = 20
|
| 15 |
+
MONSTER_DEATH_ANIM_LENGTH = 3
|
| 16 |
+
SPACE = "."
|
| 17 |
+
LADDER = "="
|
| 18 |
+
LAVA_SURFACE = "^"
|
| 19 |
+
LAVA_MIDDLE = "|"
|
| 20 |
+
WALL_SURFACE = "S"
|
| 21 |
+
WALL_MIDDLE = "A"
|
| 22 |
+
WALL_CLIFF_LEFT = "a"
|
| 23 |
+
WALL_CLIFF_RIGHT = "b"
|
| 24 |
+
COIN_OBJ1 = "1"
|
| 25 |
+
COIN_OBJ2 = "2"
|
| 26 |
+
CRATE_NORMAL = "#"
|
| 27 |
+
CRATE_DOUBLE = "$"
|
| 28 |
+
CRATE_SINGLE = "&"
|
| 29 |
+
CRATE_WARNING = "%"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def define_semantic_color_map(max_label=18):
|
| 33 |
+
assert max_label in [18, 21, 22], f"max_label {max_label} is not supported!"
|
| 34 |
+
|
| 35 |
+
semantic_color_map = {}
|
| 36 |
+
|
| 37 |
+
semantic_color_map["background"] = 0
|
| 38 |
+
|
| 39 |
+
# alien is always set to max_label (assumes it always appear in a video)
|
| 40 |
+
semantic_color_map["alien"] = max_label
|
| 41 |
+
|
| 42 |
+
if max_label == 18:
|
| 43 |
+
semantic_color_map["world"] = {
|
| 44 |
+
WALL_MIDDLE: 3,
|
| 45 |
+
WALL_SURFACE: 4,
|
| 46 |
+
WALL_CLIFF_LEFT: 5,
|
| 47 |
+
WALL_CLIFF_RIGHT: 6,
|
| 48 |
+
COIN_OBJ1: 17,
|
| 49 |
+
COIN_OBJ2: 0,
|
| 50 |
+
CRATE_NORMAL: 8,
|
| 51 |
+
CRATE_DOUBLE: 8,
|
| 52 |
+
CRATE_SINGLE: 8,
|
| 53 |
+
CRATE_WARNING: 8,
|
| 54 |
+
LAVA_MIDDLE: 1,
|
| 55 |
+
LAVA_SURFACE: 2,
|
| 56 |
+
LADDER: 7,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
semantic_color_map["shield"] = 0
|
| 60 |
+
|
| 61 |
+
semantic_color_map["monster"] = {
|
| 62 |
+
"sawHalf": 16,
|
| 63 |
+
"bee": 15,
|
| 64 |
+
"slimeBlock": 14,
|
| 65 |
+
"slimeBlue": 13,
|
| 66 |
+
"mouse": 12,
|
| 67 |
+
"snail": 11,
|
| 68 |
+
"ladybug": 10,
|
| 69 |
+
"wormPink": 9,
|
| 70 |
+
"barnacle": 0,
|
| 71 |
+
"frog": 0,
|
| 72 |
+
}
|
| 73 |
+
else:
|
| 74 |
+
semantic_color_map["world"] = {
|
| 75 |
+
WALL_MIDDLE: 3,
|
| 76 |
+
WALL_SURFACE: 4,
|
| 77 |
+
WALL_CLIFF_LEFT: 5,
|
| 78 |
+
WALL_CLIFF_RIGHT: 6,
|
| 79 |
+
COIN_OBJ1: 19,
|
| 80 |
+
COIN_OBJ2: 20,
|
| 81 |
+
CRATE_NORMAL: 8,
|
| 82 |
+
CRATE_DOUBLE: 8,
|
| 83 |
+
CRATE_SINGLE: 8,
|
| 84 |
+
CRATE_WARNING: 8,
|
| 85 |
+
LAVA_MIDDLE: 1,
|
| 86 |
+
LAVA_SURFACE: 2,
|
| 87 |
+
LADDER: 7,
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
semantic_color_map["shield"] = 21
|
| 91 |
+
|
| 92 |
+
semantic_color_map["monster"] = {
|
| 93 |
+
"sawHalf": 16,
|
| 94 |
+
"bee": 15,
|
| 95 |
+
"slimeBlock": 14,
|
| 96 |
+
"slimeBlue": 13,
|
| 97 |
+
"mouse": 12,
|
| 98 |
+
"snail": 11,
|
| 99 |
+
"ladybug": 10,
|
| 100 |
+
"wormPink": 9,
|
| 101 |
+
"barnacle": 17,
|
| 102 |
+
"frog": 18,
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
return semantic_color_map
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def generate_asset_paths(game):
|
| 109 |
+
# use background corresponding with ground theme
|
| 110 |
+
bgtheme = game.background_themes[game.world_theme_n]
|
| 111 |
+
|
| 112 |
+
gtheme = game.ground_themes[game.world_theme_n]
|
| 113 |
+
walls = "kenney/Ground/" + gtheme + "/" + gtheme.lower()
|
| 114 |
+
|
| 115 |
+
# default option with fixed agent look
|
| 116 |
+
atheme = game.agent_themes[game.agent_theme_n]
|
| 117 |
+
alien = "kenneyLarge/Players/128x256_no_helmet/" + atheme + "/alien" + atheme
|
| 118 |
+
alien_paths = {"Mugen": alien}
|
| 119 |
+
|
| 120 |
+
tiles = "kenney/Tiles/"
|
| 121 |
+
items = "kenneyLarge/Items/"
|
| 122 |
+
enemy = "kenneyLarge/Enemies/"
|
| 123 |
+
|
| 124 |
+
asset_files = {}
|
| 125 |
+
|
| 126 |
+
asset_files["background"] = bgtheme
|
| 127 |
+
|
| 128 |
+
asset_files["world"] = {
|
| 129 |
+
WALL_MIDDLE: walls + "Center.png",
|
| 130 |
+
WALL_SURFACE: walls + "Mid.png",
|
| 131 |
+
WALL_CLIFF_LEFT: walls + "Cliff_left.png",
|
| 132 |
+
WALL_CLIFF_RIGHT: walls + "Cliff_right.png",
|
| 133 |
+
COIN_OBJ1: items + "coinGold.png",
|
| 134 |
+
COIN_OBJ2: items + "gemRed.png",
|
| 135 |
+
CRATE_NORMAL: tiles + "boxCrate.png",
|
| 136 |
+
CRATE_DOUBLE: tiles + "boxCrate_double.png",
|
| 137 |
+
CRATE_SINGLE: tiles + "boxCrate_single.png",
|
| 138 |
+
CRATE_WARNING: tiles + "boxCrate_warning.png",
|
| 139 |
+
LAVA_MIDDLE: tiles + "lava.png",
|
| 140 |
+
LAVA_SURFACE: tiles + "lavaTop_low.png",
|
| 141 |
+
LADDER: tiles + "ladderMid.png",
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
asset_files["alien"] = {}
|
| 145 |
+
for alien_name in alien_paths.keys():
|
| 146 |
+
asset_files["alien"][alien_name] = {
|
| 147 |
+
"walk1": alien_paths[alien_name] + "_walk1.png",
|
| 148 |
+
"walk2": alien_paths[alien_name] + "_walk2.png",
|
| 149 |
+
"climb1": alien_paths[alien_name] + "_climb1.png",
|
| 150 |
+
"climb2": alien_paths[alien_name] + "_climb2.png",
|
| 151 |
+
"stand": alien_paths[alien_name] + "_stand.png",
|
| 152 |
+
"jump": alien_paths[alien_name] + "_jump.png",
|
| 153 |
+
"duck": alien_paths[alien_name] + "_duck.png",
|
| 154 |
+
"hit": alien_paths[alien_name] + "_hit.png",
|
| 155 |
+
}
|
| 156 |
+
asset_files["shield"] = "bubble_shield.png"
|
| 157 |
+
|
| 158 |
+
game.flatten_monster_names()
|
| 159 |
+
# monster assets are generated based on list of names used at rendering
|
| 160 |
+
asset_files["monster"] = {
|
| 161 |
+
name: enemy + name + ".png" for name in game.flattened_monster_names
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
return asset_files
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# binarize alpha channel if input img is in RGBA mode, set anything above 0 to 255
|
| 168 |
+
def binarize_alpha_channel(img):
|
| 169 |
+
if img.mode != "RGBA":
|
| 170 |
+
return img
|
| 171 |
+
|
| 172 |
+
w, h = img.size
|
| 173 |
+
for i in range(w):
|
| 174 |
+
for j in range(h):
|
| 175 |
+
pixel = img.getpixel((i, j))
|
| 176 |
+
|
| 177 |
+
# set alpha to 255 if alpha > 0
|
| 178 |
+
if pixel[3] > 0:
|
| 179 |
+
img.putpixel((i, j), (pixel[0], pixel[1], pixel[2], 255))
|
| 180 |
+
|
| 181 |
+
return img
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class Asset:
|
| 185 |
+
def __init__(
|
| 186 |
+
self,
|
| 187 |
+
name,
|
| 188 |
+
file,
|
| 189 |
+
asset_root,
|
| 190 |
+
kind="world",
|
| 191 |
+
kx=80,
|
| 192 |
+
ky=80,
|
| 193 |
+
semantic_color=(0, 0, 0),
|
| 194 |
+
flip=False,
|
| 195 |
+
binarize_alpha=False,
|
| 196 |
+
):
|
| 197 |
+
self.name = name
|
| 198 |
+
self.file = file
|
| 199 |
+
self.asset_root = asset_root
|
| 200 |
+
self.kind = kind
|
| 201 |
+
self.kx = kx
|
| 202 |
+
self.ky = ky
|
| 203 |
+
self.semantic_color = semantic_color
|
| 204 |
+
self.flip = flip
|
| 205 |
+
self.binarize_alpha = binarize_alpha
|
| 206 |
+
|
| 207 |
+
self.load_asset()
|
| 208 |
+
|
| 209 |
+
def load_asset(self):
|
| 210 |
+
asset_path = os.path.join(self.asset_root, self.file)
|
| 211 |
+
if not os.path.isfile(asset_path):
|
| 212 |
+
# basically remove the '_walk1' postfix
|
| 213 |
+
fallback_path = (
|
| 214 |
+
"_".join(asset_path.split("_")[:-1]) + "." + asset_path.split(".")[-1]
|
| 215 |
+
)
|
| 216 |
+
assert os.path.isfile(fallback_path), asset_path
|
| 217 |
+
asset_path = fallback_path
|
| 218 |
+
self.asset = Image.open(asset_path)
|
| 219 |
+
|
| 220 |
+
# used for (user control) asset swap, because alien h:w == 2:1 while others is 1:1
|
| 221 |
+
# the asset resize at loading and render grid size all need to change respectively
|
| 222 |
+
self.aspect_ratio = self.asset.size[1] / self.asset.size[0]
|
| 223 |
+
|
| 224 |
+
if self.kind == "world":
|
| 225 |
+
if self.name != LAVA_MIDDLE and self.name != LAVA_SURFACE:
|
| 226 |
+
# LAVA has a special way of rendering animation so don't resize now
|
| 227 |
+
self.asset = self.asset.resize(
|
| 228 |
+
(math.ceil(self.kx + 0.5), math.ceil(self.ky + 0.5))
|
| 229 |
+
)
|
| 230 |
+
elif self.kind == "alien":
|
| 231 |
+
self.asset = self.asset.resize(
|
| 232 |
+
(math.ceil(self.kx), math.ceil(self.aspect_ratio * self.ky))
|
| 233 |
+
)
|
| 234 |
+
elif self.kind == "shield":
|
| 235 |
+
self.asset = self.asset.resize(
|
| 236 |
+
(math.ceil(self.kx * 1.15), math.ceil(self.ky * 2.1))
|
| 237 |
+
)
|
| 238 |
+
elif self.kind == "monster" or self.kind == "background":
|
| 239 |
+
self.asset = self.asset.resize((math.ceil(self.kx), math.ceil(self.ky)))
|
| 240 |
+
else:
|
| 241 |
+
raise NotImplementedError(f"Unknown asset kind {self.kind}")
|
| 242 |
+
|
| 243 |
+
# flip if needed (for facing left/right)
|
| 244 |
+
if self.flip:
|
| 245 |
+
self.asset = self.asset.transpose(Image.FLIP_LEFT_RIGHT)
|
| 246 |
+
|
| 247 |
+
if self.binarize_alpha:
|
| 248 |
+
self.asset = binarize_alpha_channel(self.asset)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def load_assets(
|
| 252 |
+
asset_files, asset_root, semantic_color_map, kx=80, ky=80, gen_original=False
|
| 253 |
+
):
|
| 254 |
+
asset_map = {}
|
| 255 |
+
|
| 256 |
+
for kind in asset_files.keys():
|
| 257 |
+
assert kind in semantic_color_map
|
| 258 |
+
|
| 259 |
+
if kind == "background":
|
| 260 |
+
# background will be loaded separately
|
| 261 |
+
continue
|
| 262 |
+
|
| 263 |
+
if kind == "shield":
|
| 264 |
+
# asset file for the bubble shield in agent power-up mode
|
| 265 |
+
asset_map[kind] = Asset(
|
| 266 |
+
name=kind,
|
| 267 |
+
file=asset_files[kind],
|
| 268 |
+
asset_root=asset_root,
|
| 269 |
+
kind=kind,
|
| 270 |
+
kx=kx,
|
| 271 |
+
ky=ky,
|
| 272 |
+
semantic_color=semantic_color_map[kind],
|
| 273 |
+
binarize_alpha=not gen_original,
|
| 274 |
+
)
|
| 275 |
+
continue
|
| 276 |
+
|
| 277 |
+
for key in asset_files[kind].keys():
|
| 278 |
+
if kind == "world":
|
| 279 |
+
# ground asset, no need to worry about pose or facing
|
| 280 |
+
asset_map[key] = Asset(
|
| 281 |
+
name=key,
|
| 282 |
+
file=asset_files[kind][key],
|
| 283 |
+
asset_root=asset_root,
|
| 284 |
+
kind=kind,
|
| 285 |
+
kx=kx,
|
| 286 |
+
ky=ky,
|
| 287 |
+
semantic_color=semantic_color_map[kind][key],
|
| 288 |
+
binarize_alpha=not gen_original,
|
| 289 |
+
)
|
| 290 |
+
elif kind == "alien":
|
| 291 |
+
for pose in asset_files[kind][key].keys():
|
| 292 |
+
# facing right is default to empty
|
| 293 |
+
all_facings = ["", "_left"]
|
| 294 |
+
for facing in all_facings:
|
| 295 |
+
a_key = key + "_" + pose + facing
|
| 296 |
+
|
| 297 |
+
asset_map[a_key] = Asset(
|
| 298 |
+
name=a_key,
|
| 299 |
+
file=asset_files[kind][key][pose],
|
| 300 |
+
asset_root=asset_root,
|
| 301 |
+
kind=kind,
|
| 302 |
+
kx=kx,
|
| 303 |
+
ky=ky,
|
| 304 |
+
semantic_color=semantic_color_map[kind],
|
| 305 |
+
flip=(facing != ""), # flip the asset if facing is not ''
|
| 306 |
+
binarize_alpha=not gen_original,
|
| 307 |
+
)
|
| 308 |
+
elif kind == "monster":
|
| 309 |
+
# for monsters, 3 types of assets will be loaded
|
| 310 |
+
# for each of them, facing can be left or right
|
| 311 |
+
all_poses = ["", "_move", "_dead"] # walk1 is default to empty
|
| 312 |
+
all_facings = ["", "_right"] # facing left is default to empty
|
| 313 |
+
base_fn = os.path.splitext(asset_files[kind][key])[
|
| 314 |
+
0
|
| 315 |
+
] # e.g. Enemies/bee
|
| 316 |
+
for pose in all_poses:
|
| 317 |
+
for facing in all_facings:
|
| 318 |
+
m_key = key + pose + facing
|
| 319 |
+
file_name = base_fn + pose + ".png"
|
| 320 |
+
|
| 321 |
+
asset_map[m_key] = Asset(
|
| 322 |
+
name=m_key,
|
| 323 |
+
file=file_name,
|
| 324 |
+
asset_root=asset_root,
|
| 325 |
+
kind="monster",
|
| 326 |
+
kx=kx,
|
| 327 |
+
ky=ky,
|
| 328 |
+
semantic_color=semantic_color_map[kind][key],
|
| 329 |
+
flip=(facing != ""), # flip the asset if facing is not ''
|
| 330 |
+
binarize_alpha=not gen_original,
|
| 331 |
+
)
|
| 332 |
+
else:
|
| 333 |
+
raise NotImplementedError(f"Unknown asset kind {kind}")
|
| 334 |
+
|
| 335 |
+
return asset_map
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# load background asset, zoom is different so need a separate function
|
| 339 |
+
def load_bg_asset(asset_files, asset_root, semantic_color_map, zx, zy):
|
| 340 |
+
kind = "background"
|
| 341 |
+
bg_asset = Asset(
|
| 342 |
+
name=kind,
|
| 343 |
+
file=asset_files[kind],
|
| 344 |
+
asset_root=asset_root,
|
| 345 |
+
kind=kind,
|
| 346 |
+
kx=zx,
|
| 347 |
+
ky=zy,
|
| 348 |
+
semantic_color=semantic_color_map[kind],
|
| 349 |
+
)
|
| 350 |
+
return bg_asset
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# used for alien dying animation in gen_original mode
|
| 354 |
+
def get_transparent_asset(input_asset, transparency):
|
| 355 |
+
assert input_asset.mode == "RGBA"
|
| 356 |
+
np_asset = np.array(input_asset, dtype=np.int16)
|
| 357 |
+
np_asset[:, :, 3] -= transparency
|
| 358 |
+
np_asset[:, :, 3] = np.clip(np_asset[:, :, 3], 0, None)
|
| 359 |
+
return Image.fromarray(np_asset.astype(np.uint8))
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# return rect in integer values, floor for x1,y1, ceil for x2,y2 or w,h
|
| 363 |
+
def integer_rect(rect):
|
| 364 |
+
return [
|
| 365 |
+
math.floor(rect[0]),
|
| 366 |
+
math.floor(rect[1]),
|
| 367 |
+
math.ceil(rect[2]),
|
| 368 |
+
math.ceil(rect[3]),
|
| 369 |
+
]
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def convert_xywh_to_xyxy(rect):
|
| 373 |
+
return [rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3]]
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def convert_xyxy_to_xywh(rect):
|
| 377 |
+
return [rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1]]
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# rect format is xywh, img_size is (w,h)
|
| 381 |
+
def check_out_of_bounds(rect, img_size):
|
| 382 |
+
if rect[0] + rect[2] < 0:
|
| 383 |
+
return True
|
| 384 |
+
if rect[0] > img_size[0]:
|
| 385 |
+
return True
|
| 386 |
+
if rect[1] + rect[3] < 0:
|
| 387 |
+
return True
|
| 388 |
+
if rect[1] > img_size[1]:
|
| 389 |
+
return True
|
| 390 |
+
return False
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# return intersect of two rects, input and output are both in xywh format
|
| 394 |
+
def intersect_rects(rect1, rect2):
|
| 395 |
+
xyxy_rect1 = convert_xywh_to_xyxy(rect1)
|
| 396 |
+
xyxy_rect2 = convert_xywh_to_xyxy(rect2)
|
| 397 |
+
xyxy_res_rect = [
|
| 398 |
+
max(xyxy_rect1[0], xyxy_rect2[0]),
|
| 399 |
+
max(xyxy_rect1[1], xyxy_rect2[1]),
|
| 400 |
+
min(xyxy_rect1[2], xyxy_rect2[2]),
|
| 401 |
+
min(xyxy_rect1[3], xyxy_rect2[3]),
|
| 402 |
+
]
|
| 403 |
+
|
| 404 |
+
xywh_res_rect = convert_xyxy_to_xywh(xyxy_res_rect)
|
| 405 |
+
|
| 406 |
+
# check if the intersection is empty
|
| 407 |
+
if xywh_res_rect[2] > 0 and xywh_res_rect[3] > 0:
|
| 408 |
+
return xywh_res_rect
|
| 409 |
+
else:
|
| 410 |
+
return None
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
# rect is in the format of xywh
|
| 414 |
+
def paint_color_in_rect_with_mask(
|
| 415 |
+
img, rect, color, mask, gen_original=False, ignore_mask=False, cut_mask_top_ratio=0
|
| 416 |
+
):
|
| 417 |
+
w, h = mask.size
|
| 418 |
+
img_w, img_h = img.size
|
| 419 |
+
# in some cases, mask size doesn't match the rect (e.g. monster dying)
|
| 420 |
+
if rect[2] != w or rect[3] != h:
|
| 421 |
+
if not gen_original:
|
| 422 |
+
mask = mask.resize((rect[2], rect[3]), resample=Image.NEAREST)
|
| 423 |
+
else:
|
| 424 |
+
mask = mask.resize((rect[2], rect[3]))
|
| 425 |
+
w, h = mask.size
|
| 426 |
+
|
| 427 |
+
if not gen_original:
|
| 428 |
+
# generate semantic map
|
| 429 |
+
if ignore_mask and cut_mask_top_ratio != 0:
|
| 430 |
+
# specifically for agent because its asset has a large empty area in the top,
|
| 431 |
+
# we don't want it to be fully masked
|
| 432 |
+
if cut_mask_top_ratio < 0:
|
| 433 |
+
# automatic calculate the first non-empty row from top
|
| 434 |
+
np_mask = np.array(mask)
|
| 435 |
+
cut_mask_top_rows = (np_mask.T[0].sum(axis=0) != 0).argmax(axis=0)
|
| 436 |
+
else:
|
| 437 |
+
cut_mask_top_rows = int(cut_mask_top_ratio * rect[2])
|
| 438 |
+
rect[1] += cut_mask_top_rows
|
| 439 |
+
rect[3] = mask.size[1] - cut_mask_top_rows
|
| 440 |
+
|
| 441 |
+
img = img.paste(color, convert_xywh_to_xyxy(rect))
|
| 442 |
+
else:
|
| 443 |
+
# paste in single color if generating semantic maps (so not original)
|
| 444 |
+
# if ignore_mask, this will generate a complete block mask same as rect
|
| 445 |
+
img = img.paste(
|
| 446 |
+
color,
|
| 447 |
+
convert_xywh_to_xyxy(rect),
|
| 448 |
+
mask if (mask.mode == "RGBA" and not ignore_mask) else None,
|
| 449 |
+
)
|
| 450 |
+
else:
|
| 451 |
+
# generate rgb data
|
| 452 |
+
img = img.paste(
|
| 453 |
+
mask, convert_xywh_to_xyxy(rect), mask if mask.mode == "RGBA" else None
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
return
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def draw_game_frame(
|
| 460 |
+
game,
|
| 461 |
+
frame_id,
|
| 462 |
+
asset_map,
|
| 463 |
+
kx,
|
| 464 |
+
ky,
|
| 465 |
+
gen_original=False,
|
| 466 |
+
bbox_smap_for_agent=False,
|
| 467 |
+
bbox_smap_for_monsters=False,
|
| 468 |
+
alien_name=None,
|
| 469 |
+
skip_foreground=False,
|
| 470 |
+
skip_background=False,
|
| 471 |
+
skip_mugen=False,
|
| 472 |
+
only_mugen=False,
|
| 473 |
+
):
|
| 474 |
+
# set default alien name/key
|
| 475 |
+
if alien_name is None:
|
| 476 |
+
alien_name = "Mugen"
|
| 477 |
+
|
| 478 |
+
# initialize an empty image (all zero, for background)
|
| 479 |
+
if not gen_original:
|
| 480 |
+
img = Image.new("L", (game.video_res, game.video_res))
|
| 481 |
+
else:
|
| 482 |
+
img = Image.new("RGB", (game.video_res, game.video_res))
|
| 483 |
+
|
| 484 |
+
video_center = (game.video_res - 1) // 2
|
| 485 |
+
|
| 486 |
+
frame = game.frames[frame_id]
|
| 487 |
+
|
| 488 |
+
# for agent-centric
|
| 489 |
+
# dx = -frame.agent.x * kx + video_center - 0.5 * kx
|
| 490 |
+
# dy = frame.agent.y * ky - video_center - 0.5 * ky
|
| 491 |
+
# for video data (no vertical camera move)
|
| 492 |
+
dx = -frame.agent.x * kx + video_center - 0.5 * kx
|
| 493 |
+
|
| 494 |
+
# different dy/ky ratio based on zoom level, to adjust camera view
|
| 495 |
+
if game.zoom == 5.5:
|
| 496 |
+
dy_ratio = 5.0
|
| 497 |
+
elif game.zoom == 4.3:
|
| 498 |
+
dy_ratio = 6.5
|
| 499 |
+
elif game.zoom == 5.0:
|
| 500 |
+
dy_ratio = 5.5
|
| 501 |
+
elif game.zoom == 6.0:
|
| 502 |
+
dy_ratio = 4.5
|
| 503 |
+
else:
|
| 504 |
+
raise NotImplementedError(f"zoom level {game.zoom} is not supported!")
|
| 505 |
+
dy = -video_center + dy_ratio * ky
|
| 506 |
+
|
| 507 |
+
# update background image with proper zoom for gen_original mode
|
| 508 |
+
# NOTE: if desired background label is not zero, set it here to asset_map['background'].semantic_color
|
| 509 |
+
if gen_original and not skip_background and not only_mugen:
|
| 510 |
+
zx = game.video_res * game.zoom
|
| 511 |
+
zy = zx
|
| 512 |
+
for tile_x in range(-1, 3):
|
| 513 |
+
for tile_y in range(-1, 2):
|
| 514 |
+
bg_rect = [0, 0, zx, zy]
|
| 515 |
+
bg_rect[0] = (
|
| 516 |
+
zx * tile_x
|
| 517 |
+
+ video_center
|
| 518 |
+
+ game.bgzoom * (dx + kx * game.maze_h / 2)
|
| 519 |
+
- zx * 0.5
|
| 520 |
+
)
|
| 521 |
+
bg_rect[1] = (
|
| 522 |
+
zy * tile_y
|
| 523 |
+
+ video_center
|
| 524 |
+
+ game.bgzoom * (dy - ky * game.maze_h / 2)
|
| 525 |
+
- zy * 0.5
|
| 526 |
+
)
|
| 527 |
+
if check_out_of_bounds(bg_rect, img.size):
|
| 528 |
+
continue
|
| 529 |
+
img.paste(
|
| 530 |
+
asset_map["background"].asset,
|
| 531 |
+
convert_xywh_to_xyxy(integer_rect(bg_rect)),
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# NOTE: game engine now hard-code 64 for maze_size
|
| 535 |
+
radius = int(1 + game.maze_w / game.zoom)
|
| 536 |
+
ix = int(frame.agent.x + 0.5)
|
| 537 |
+
iy = int(frame.agent.y + 0.5)
|
| 538 |
+
x_start = max(ix - radius, 0)
|
| 539 |
+
x_end = min(ix + radius + 1, game.maze_w)
|
| 540 |
+
y_start = max(iy - radius, 0)
|
| 541 |
+
y_end = min(iy + radius + 1, game.maze_h)
|
| 542 |
+
win_h = game.video_res
|
| 543 |
+
|
| 544 |
+
# convert eaten coins to a set for faster checking coordinates
|
| 545 |
+
coins_eaten_set = {tuple(coin_coord) for coin_coord in frame.coins_eaten}
|
| 546 |
+
|
| 547 |
+
if not skip_background and not only_mugen:
|
| 548 |
+
for y in range(y_start, y_end):
|
| 549 |
+
for x in range(x_start, x_end):
|
| 550 |
+
wkey = game.maze[y][x]
|
| 551 |
+
if wkey == SPACE:
|
| 552 |
+
continue
|
| 553 |
+
|
| 554 |
+
# eaten coins is treated the same as SPACE, just continue
|
| 555 |
+
# but we should not modify the coins in maze to SPACE, or it may cause inconsistency
|
| 556 |
+
# if we ever need to render backwards or save json after drawing
|
| 557 |
+
if (x, y) in coins_eaten_set:
|
| 558 |
+
continue
|
| 559 |
+
|
| 560 |
+
assert wkey in asset_map, f"{wkey} not in assets!"
|
| 561 |
+
|
| 562 |
+
tile_rect = [
|
| 563 |
+
kx * x + dx - 0.1,
|
| 564 |
+
win_h - ky * y + dy - 0.1,
|
| 565 |
+
kx + 0.5 + 0.2,
|
| 566 |
+
ky + 0.5 + 0.2,
|
| 567 |
+
]
|
| 568 |
+
|
| 569 |
+
# skip tile if the rect is completely out-of-bounds
|
| 570 |
+
if check_out_of_bounds(tile_rect, img.size):
|
| 571 |
+
continue
|
| 572 |
+
|
| 573 |
+
if wkey == LAVA_MIDDLE or wkey == LAVA_SURFACE:
|
| 574 |
+
d1 = tile_rect[:]
|
| 575 |
+
d2 = tile_rect[:]
|
| 576 |
+
asset_size = asset_map[wkey].asset.size
|
| 577 |
+
sr = [0, 0, asset_size[0], asset_size[1]]
|
| 578 |
+
sr1 = sr[:]
|
| 579 |
+
sr2 = sr[:]
|
| 580 |
+
tr = frame.state_time * 0.1
|
| 581 |
+
tr -= int(tr)
|
| 582 |
+
tr *= -1
|
| 583 |
+
d1[0] += tr * tile_rect[2]
|
| 584 |
+
d2[0] += tile_rect[2] + tr * tile_rect[2]
|
| 585 |
+
sr1[0] += -tr * asset_size[0]
|
| 586 |
+
sr2[0] += -asset_size[0] - tr * asset_size[0]
|
| 587 |
+
d1 = intersect_rects(d1, tile_rect)
|
| 588 |
+
d2 = intersect_rects(d2, tile_rect)
|
| 589 |
+
if d1 is not None:
|
| 590 |
+
d1[2] += 0.5
|
| 591 |
+
if d2 is not None:
|
| 592 |
+
d2[0] -= 0.5
|
| 593 |
+
d2[2] += 0.5
|
| 594 |
+
sr1 = intersect_rects(sr1, sr)
|
| 595 |
+
sr2 = intersect_rects(sr2, sr)
|
| 596 |
+
if sr1 is not None and d1 is not None:
|
| 597 |
+
# crop and render one half of the asset
|
| 598 |
+
crop_mask = asset_map[wkey].asset.crop(
|
| 599 |
+
integer_rect(convert_xywh_to_xyxy(sr1))
|
| 600 |
+
)
|
| 601 |
+
paint_color_in_rect_with_mask(
|
| 602 |
+
img,
|
| 603 |
+
integer_rect(d1),
|
| 604 |
+
asset_map[wkey].semantic_color,
|
| 605 |
+
crop_mask,
|
| 606 |
+
gen_original=gen_original,
|
| 607 |
+
)
|
| 608 |
+
if sr2 is not None and d2 is not None:
|
| 609 |
+
# crop and render the other half of the asset (swapped places horizontally)
|
| 610 |
+
crop_mask = asset_map[wkey].asset.crop(
|
| 611 |
+
integer_rect(convert_xywh_to_xyxy(sr2))
|
| 612 |
+
)
|
| 613 |
+
paint_color_in_rect_with_mask(
|
| 614 |
+
img,
|
| 615 |
+
integer_rect(d2),
|
| 616 |
+
asset_map[wkey].semantic_color,
|
| 617 |
+
crop_mask,
|
| 618 |
+
gen_original=gen_original,
|
| 619 |
+
)
|
| 620 |
+
else:
|
| 621 |
+
paint_color_in_rect_with_mask(
|
| 622 |
+
img,
|
| 623 |
+
integer_rect(tile_rect),
|
| 624 |
+
asset_map[wkey].semantic_color,
|
| 625 |
+
asset_map[wkey].asset,
|
| 626 |
+
gen_original=gen_original,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
if not skip_foreground:
|
| 630 |
+
if not only_mugen:
|
| 631 |
+
# paint monsters
|
| 632 |
+
for mi in range(len(frame.monsters)):
|
| 633 |
+
if frame.monsters[mi].is_dead:
|
| 634 |
+
dying_frame_cnt = max(0, frame.monsters[mi].monster_dying_frame_cnt)
|
| 635 |
+
monster_shrinkage = (
|
| 636 |
+
(MONSTER_DEATH_ANIM_LENGTH - dying_frame_cnt)
|
| 637 |
+
* 0.8
|
| 638 |
+
/ MONSTER_DEATH_ANIM_LENGTH
|
| 639 |
+
)
|
| 640 |
+
monster_rect = [
|
| 641 |
+
math.floor(kx * frame.monsters[mi].x + dx),
|
| 642 |
+
math.floor(
|
| 643 |
+
win_h
|
| 644 |
+
- ky * frame.monsters[mi].y
|
| 645 |
+
+ dy
|
| 646 |
+
+ ky * monster_shrinkage
|
| 647 |
+
),
|
| 648 |
+
math.ceil(kx),
|
| 649 |
+
math.ceil(ky * (1 - monster_shrinkage)),
|
| 650 |
+
]
|
| 651 |
+
else:
|
| 652 |
+
monster_rect = [
|
| 653 |
+
math.floor(kx * frame.monsters[mi].x + dx),
|
| 654 |
+
math.floor(win_h - ky * frame.monsters[mi].y + dy),
|
| 655 |
+
math.ceil(kx),
|
| 656 |
+
math.ceil(ky),
|
| 657 |
+
]
|
| 658 |
+
|
| 659 |
+
m_name = game.flattened_monster_names[frame.monsters[mi].theme]
|
| 660 |
+
# add pose and facing to the key to find correct asset
|
| 661 |
+
m_pose = "" if frame.monsters[mi].walk1_mode else "_move"
|
| 662 |
+
if frame.monsters[mi].is_dead:
|
| 663 |
+
m_pose = "_dead"
|
| 664 |
+
m_key = (
|
| 665 |
+
m_name + m_pose + ("_right" if frame.monsters[mi].vx > 0 else "")
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
paint_color_in_rect_with_mask(
|
| 669 |
+
img,
|
| 670 |
+
monster_rect,
|
| 671 |
+
asset_map[m_key].semantic_color,
|
| 672 |
+
asset_map[m_key].asset,
|
| 673 |
+
gen_original=gen_original,
|
| 674 |
+
ignore_mask=bbox_smap_for_monsters,
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
if not skip_mugen:
|
| 678 |
+
# paint agent - do it after monsters so agent is always in front
|
| 679 |
+
a_key = (
|
| 680 |
+
alien_name
|
| 681 |
+
+ "_"
|
| 682 |
+
+ frame.agent.pose
|
| 683 |
+
+ ("" if frame.agent.is_facing_right else "_left")
|
| 684 |
+
)
|
| 685 |
+
# note how aspect_ratio is used for alien rect, this can be applied to
|
| 686 |
+
# monster rect to support asset that's not 1:1 (e.g. use alien as monster)
|
| 687 |
+
alien_rect = [
|
| 688 |
+
math.floor(kx * frame.agent.x + dx),
|
| 689 |
+
# math.floor(win_h - ky * (frame.agent.y + 1) + dy), # default for 2:1 alien, no asset swap
|
| 690 |
+
math.floor(
|
| 691 |
+
win_h
|
| 692 |
+
- ky * (frame.agent.y + asset_map[a_key].aspect_ratio - 1)
|
| 693 |
+
+ dy
|
| 694 |
+
),
|
| 695 |
+
math.ceil(kx),
|
| 696 |
+
# math.ceil(2 * ky), # default for 2:1 alien, no asset swap
|
| 697 |
+
math.ceil(asset_map[a_key].aspect_ratio * ky),
|
| 698 |
+
]
|
| 699 |
+
if frame.agent.is_killed:
|
| 700 |
+
transparency = (
|
| 701 |
+
DEATH_ANIM_LENGTH + 1 - frame.agent.killed_animation_frame_cnt
|
| 702 |
+
) * 12
|
| 703 |
+
# only render if not fully transparent
|
| 704 |
+
if transparency > 255:
|
| 705 |
+
agent_asset = None
|
| 706 |
+
else:
|
| 707 |
+
if gen_original:
|
| 708 |
+
agent_asset = get_transparent_asset(
|
| 709 |
+
asset_map[a_key].asset, transparency
|
| 710 |
+
)
|
| 711 |
+
else:
|
| 712 |
+
# when generating semantic map, alien mask won't change unless fully transparent
|
| 713 |
+
agent_asset = asset_map[a_key].asset
|
| 714 |
+
else:
|
| 715 |
+
agent_asset = asset_map[a_key].asset
|
| 716 |
+
if agent_asset is not None:
|
| 717 |
+
paint_color_in_rect_with_mask(
|
| 718 |
+
img,
|
| 719 |
+
alien_rect,
|
| 720 |
+
asset_map[a_key].semantic_color,
|
| 721 |
+
agent_asset,
|
| 722 |
+
gen_original=gen_original,
|
| 723 |
+
ignore_mask=bbox_smap_for_agent,
|
| 724 |
+
cut_mask_top_ratio=0.8,
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
# paint the bubble shield if agent is in power-up mode
|
| 728 |
+
if frame.agent.power_up_mode:
|
| 729 |
+
shield_rect = [
|
| 730 |
+
# NOTE: game engine hard-codes 7 and 8 for co-ordinates which won't work with video-res that's not 1024
|
| 731 |
+
# (for training we usually generate with 256 or 128 video_res), so need to convert them
|
| 732 |
+
math.floor(kx * frame.agent.x + dx - 7 * game.video_res / 1024),
|
| 733 |
+
math.floor(
|
| 734 |
+
win_h
|
| 735 |
+
- ky * (frame.agent.y + 1)
|
| 736 |
+
+ dy
|
| 737 |
+
+ 8 * game.video_res / 1024
|
| 738 |
+
),
|
| 739 |
+
math.ceil(kx * 1.15),
|
| 740 |
+
math.ceil(ky * 2.1),
|
| 741 |
+
]
|
| 742 |
+
# pull bubble down when Mugen crouches
|
| 743 |
+
if frame.agent.pose == "duck":
|
| 744 |
+
shield_rect[1] += math.floor(8 * game.video_res / 1024)
|
| 745 |
+
|
| 746 |
+
paint_color_in_rect_with_mask(
|
| 747 |
+
img,
|
| 748 |
+
shield_rect,
|
| 749 |
+
asset_map["shield"].semantic_color,
|
| 750 |
+
asset_map["shield"].asset,
|
| 751 |
+
gen_original=gen_original,
|
| 752 |
+
ignore_mask=bbox_smap_for_agent,
|
| 753 |
+
cut_mask_top_ratio=0.45,
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
return img
|
multimodal/examples/mugen/data/coinrun/game.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Game:
|
| 11 |
+
def __init__(self, **kwargs):
|
| 12 |
+
self.game_id = -1
|
| 13 |
+
self.level_seed = 0
|
| 14 |
+
self.rl_agent_seed = 0
|
| 15 |
+
self.zoom = 5.5
|
| 16 |
+
self.bgzoom = 0.4 # NOTE: hard-coded
|
| 17 |
+
self.world_theme_n = -1
|
| 18 |
+
self.agent_theme_n = -1
|
| 19 |
+
|
| 20 |
+
self.background_themes = []
|
| 21 |
+
self.ground_themes = []
|
| 22 |
+
self.agent_themes = []
|
| 23 |
+
self.monster_names = {}
|
| 24 |
+
self.flattened_monster_names = []
|
| 25 |
+
|
| 26 |
+
# TODO: save and load these from the game engine
|
| 27 |
+
self.video_res = 1024
|
| 28 |
+
self.maze_w = 64
|
| 29 |
+
self.maze_h = 13 # for zoom 5.5
|
| 30 |
+
|
| 31 |
+
self.reset_game()
|
| 32 |
+
|
| 33 |
+
self.__dict__.update(**kwargs)
|
| 34 |
+
self.frames = [Frame(**f) for f in self.frames]
|
| 35 |
+
|
| 36 |
+
def reset_game(self):
|
| 37 |
+
self.maze = None
|
| 38 |
+
self.frames = []
|
| 39 |
+
|
| 40 |
+
def asdict(self, f_start=-1, f_end=-1):
|
| 41 |
+
if f_end < 0:
|
| 42 |
+
# show all frames by default
|
| 43 |
+
frames_as_dict = [f.asdict() for f in self.frames]
|
| 44 |
+
else:
|
| 45 |
+
frames_as_dict = [f.asdict() for f in self.frames[f_start:f_end]]
|
| 46 |
+
return {
|
| 47 |
+
"game_id": self.game_id,
|
| 48 |
+
"level_seed": self.level_seed,
|
| 49 |
+
"rl_agent_seed": self.rl_agent_seed,
|
| 50 |
+
"zoom": self.zoom,
|
| 51 |
+
"bgzoom": self.bgzoom,
|
| 52 |
+
"world_theme_n": self.world_theme_n,
|
| 53 |
+
"agent_theme_n": self.agent_theme_n,
|
| 54 |
+
"background_themes": self.background_themes,
|
| 55 |
+
"ground_themes": self.ground_themes,
|
| 56 |
+
"agent_themes": self.agent_themes,
|
| 57 |
+
"monster_names": self.monster_names,
|
| 58 |
+
"video_res": self.video_res,
|
| 59 |
+
"maze_w": self.maze_w,
|
| 60 |
+
"maze_h": self.maze_h,
|
| 61 |
+
"maze": self.maze if self.maze is not None else None,
|
| 62 |
+
"frames": frames_as_dict,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
def __repr__(self):
|
| 66 |
+
return json.dumps(self.asdict())
|
| 67 |
+
|
| 68 |
+
def save_json(self, json_path, f_start=-1, f_end=-1):
|
| 69 |
+
with open(json_path, "w") as f:
|
| 70 |
+
json.dump(self.asdict(f_start, f_end), f, indent=2)
|
| 71 |
+
|
| 72 |
+
def load_json(self, json_path):
|
| 73 |
+
with open(json_path, "r") as f:
|
| 74 |
+
data = json.load(f)
|
| 75 |
+
|
| 76 |
+
self.reset_game()
|
| 77 |
+
self.__dict__.update(**data)
|
| 78 |
+
self.frames = [Frame(**f) for f in self.frames]
|
| 79 |
+
|
| 80 |
+
self.flatten_monster_names()
|
| 81 |
+
self.reset_eaten_coins()
|
| 82 |
+
|
| 83 |
+
def flatten_monster_names(self):
|
| 84 |
+
# the order is important!
|
| 85 |
+
self.flattened_monster_names = self.monster_names["ground"]
|
| 86 |
+
self.flattened_monster_names.extend(self.monster_names["walking"])
|
| 87 |
+
self.flattened_monster_names.extend(self.monster_names["flying"])
|
| 88 |
+
|
| 89 |
+
# NOTE: some coins might be missing due to how 3s clip json is saved
|
| 90 |
+
# reset all eaten coins to put them back
|
| 91 |
+
# this is a temporary fix until we regenerate all jsons
|
| 92 |
+
def reset_eaten_coins(self):
|
| 93 |
+
for coin_loc in self.frames[-1].coins_eaten:
|
| 94 |
+
# note the game rows are saved as strings
|
| 95 |
+
# NOTE: '1' is the yellow coin, we also has another type '2' that is the red gem
|
| 96 |
+
# but the json with '2' enabled should not have this issue
|
| 97 |
+
if self.maze[coin_loc[1]][coin_loc[0]] == ".":
|
| 98 |
+
self.maze[coin_loc[1]] = (
|
| 99 |
+
self.maze[coin_loc[1]][: coin_loc[0]]
|
| 100 |
+
+ "1"
|
| 101 |
+
+ self.maze[coin_loc[1]][(coin_loc[0] + 1) :]
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Frame:
|
| 106 |
+
def __init__(self, **kwargs):
|
| 107 |
+
self.frame_id = -1
|
| 108 |
+
self.file_name = ""
|
| 109 |
+
self.state_time = 0
|
| 110 |
+
self.coins_eaten = []
|
| 111 |
+
self.agent = None
|
| 112 |
+
self.monsters = []
|
| 113 |
+
|
| 114 |
+
self.__dict__.update(**kwargs)
|
| 115 |
+
if "agent" in self.__dict__ and self.agent is not None:
|
| 116 |
+
self.agent = Agent(**self.agent)
|
| 117 |
+
if "monsters" in self.__dict__:
|
| 118 |
+
self.monsters = [Monster(**m) for m in self.monsters]
|
| 119 |
+
|
| 120 |
+
def asdict(self):
|
| 121 |
+
return {
|
| 122 |
+
"frame_id": self.frame_id,
|
| 123 |
+
"file_name": self.file_name,
|
| 124 |
+
"state_time": self.state_time,
|
| 125 |
+
"coins_eaten": self.coins_eaten,
|
| 126 |
+
"agent": self.agent.asdict() if self.agent is not None else None,
|
| 127 |
+
"monsters": [m.asdict() for m in self.monsters],
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
def __repr__(self):
|
| 131 |
+
return json.dumps(self.asdict())
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class Agent:
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
x,
|
| 138 |
+
y,
|
| 139 |
+
vx=0.0,
|
| 140 |
+
vy=0.0,
|
| 141 |
+
time_alive=0,
|
| 142 |
+
ladder=False,
|
| 143 |
+
spring=0,
|
| 144 |
+
is_killed=False,
|
| 145 |
+
killed_animation_frame_cnt=0,
|
| 146 |
+
finished_level_frame_cnt=0,
|
| 147 |
+
killed_monster=False,
|
| 148 |
+
bumped_head=False,
|
| 149 |
+
collected_coin=False,
|
| 150 |
+
collected_gem=False,
|
| 151 |
+
power_up_mode=False,
|
| 152 |
+
**kwargs,
|
| 153 |
+
):
|
| 154 |
+
self.x = x
|
| 155 |
+
self.y = y
|
| 156 |
+
self.vx = vx
|
| 157 |
+
self.vy = vy
|
| 158 |
+
self.time_alive = time_alive
|
| 159 |
+
self.ladder = ladder # for climb pose
|
| 160 |
+
self.spring = spring # for duck pose
|
| 161 |
+
|
| 162 |
+
# states related to agent dying or finishing animations
|
| 163 |
+
self.is_killed = is_killed
|
| 164 |
+
self.killed_animation_frame_cnt = killed_animation_frame_cnt
|
| 165 |
+
self.finished_level_frame_cnt = finished_level_frame_cnt
|
| 166 |
+
self.killed_monster = killed_monster
|
| 167 |
+
self.bumped_head = bumped_head
|
| 168 |
+
self.collected_coin = collected_coin
|
| 169 |
+
self.collected_gem = collected_gem
|
| 170 |
+
self.power_up_mode = power_up_mode
|
| 171 |
+
|
| 172 |
+
self.anim_freq = 5 # hard-coded
|
| 173 |
+
|
| 174 |
+
# decide whether to flip asset horizontally
|
| 175 |
+
self.is_facing_right = True
|
| 176 |
+
if self.vx < 0:
|
| 177 |
+
self.is_facing_right = False
|
| 178 |
+
|
| 179 |
+
# decide which of the two walk/climb asset to use
|
| 180 |
+
self.walk1_mode = True
|
| 181 |
+
if (self.time_alive // self.anim_freq) % 2 != 0:
|
| 182 |
+
self.walk1_mode = False
|
| 183 |
+
|
| 184 |
+
self.pose = self.get_pose()
|
| 185 |
+
|
| 186 |
+
# kwargs are ignored
|
| 187 |
+
# self.__dict__.update(**kwargs)
|
| 188 |
+
|
| 189 |
+
def get_pose(self):
|
| 190 |
+
if self.is_killed:
|
| 191 |
+
return "hit"
|
| 192 |
+
if self.ladder:
|
| 193 |
+
if self.walk1_mode:
|
| 194 |
+
return "climb1"
|
| 195 |
+
else:
|
| 196 |
+
return "climb2"
|
| 197 |
+
if self.vy != 0:
|
| 198 |
+
return "jump"
|
| 199 |
+
if self.spring != 0:
|
| 200 |
+
return "duck"
|
| 201 |
+
if self.vx == 0:
|
| 202 |
+
return "stand"
|
| 203 |
+
if self.walk1_mode:
|
| 204 |
+
return "walk1"
|
| 205 |
+
else:
|
| 206 |
+
return "walk2"
|
| 207 |
+
|
| 208 |
+
def asdict(self):
|
| 209 |
+
return {
|
| 210 |
+
"x": self.x,
|
| 211 |
+
"y": self.y,
|
| 212 |
+
"vx": self.vx,
|
| 213 |
+
"vy": self.vy,
|
| 214 |
+
"time_alive": self.time_alive,
|
| 215 |
+
"ladder": self.ladder,
|
| 216 |
+
"spring": self.spring,
|
| 217 |
+
"is_killed": self.is_killed,
|
| 218 |
+
"killed_animation_frame_cnt": self.killed_animation_frame_cnt,
|
| 219 |
+
"finished_level_frame_cnt": self.finished_level_frame_cnt,
|
| 220 |
+
"killed_monster": self.killed_monster,
|
| 221 |
+
"bumped_head": self.bumped_head,
|
| 222 |
+
"collected_coin": self.collected_coin,
|
| 223 |
+
"collected_gem": self.collected_gem,
|
| 224 |
+
"power_up_mode": self.power_up_mode,
|
| 225 |
+
"anim_freq": self.anim_freq,
|
| 226 |
+
"is_facing_right": self.is_facing_right,
|
| 227 |
+
"walk1_mode": self.walk1_mode,
|
| 228 |
+
"pose": self.pose,
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
def __repr__(self):
|
| 232 |
+
return json.dumps(self.asdict())
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class Monster:
|
| 236 |
+
def __init__(
|
| 237 |
+
self,
|
| 238 |
+
m_id,
|
| 239 |
+
x,
|
| 240 |
+
y,
|
| 241 |
+
vx=0.0,
|
| 242 |
+
vy=0.0,
|
| 243 |
+
theme=0,
|
| 244 |
+
is_flying=False,
|
| 245 |
+
is_walking=False,
|
| 246 |
+
is_jumping=False,
|
| 247 |
+
is_dead=False,
|
| 248 |
+
time=0,
|
| 249 |
+
anim_freq=1,
|
| 250 |
+
monster_dying_frame_cnt=0,
|
| 251 |
+
**kwargs,
|
| 252 |
+
):
|
| 253 |
+
self.m_id = m_id
|
| 254 |
+
self.x = x
|
| 255 |
+
self.y = y
|
| 256 |
+
self.vx = vx
|
| 257 |
+
self.vy = vy
|
| 258 |
+
self.theme = theme # monster type (saw, snail, slime, etc.)
|
| 259 |
+
self.is_flying = is_flying
|
| 260 |
+
self.is_walking = is_walking
|
| 261 |
+
self.is_jumping = is_jumping
|
| 262 |
+
self.is_dead = is_dead
|
| 263 |
+
self.time = time
|
| 264 |
+
self.anim_freq = anim_freq
|
| 265 |
+
self.monster_dying_frame_cnt = monster_dying_frame_cnt
|
| 266 |
+
|
| 267 |
+
# decide which of the two walk/climb asset to use
|
| 268 |
+
self.walk1_mode = True
|
| 269 |
+
if self.is_jumping:
|
| 270 |
+
# for jumping monster, walk1 asset is decided by vertical speed
|
| 271 |
+
if self.vy != 0:
|
| 272 |
+
self.walk1_mode = False
|
| 273 |
+
elif (self.time // self.anim_freq) % 2 != 0:
|
| 274 |
+
self.walk1_mode = False
|
| 275 |
+
|
| 276 |
+
def asdict(self):
|
| 277 |
+
return {
|
| 278 |
+
"m_id": self.m_id,
|
| 279 |
+
"x": self.x,
|
| 280 |
+
"y": self.y,
|
| 281 |
+
"vx": self.vx,
|
| 282 |
+
"vy": self.vy,
|
| 283 |
+
"theme": self.theme,
|
| 284 |
+
"is_flying": self.is_flying,
|
| 285 |
+
"is_walking": self.is_walking,
|
| 286 |
+
"is_jumping": self.is_jumping,
|
| 287 |
+
"is_dead": self.is_dead,
|
| 288 |
+
"time": self.time,
|
| 289 |
+
"anim_freq": self.anim_freq,
|
| 290 |
+
"monster_dying_frame_cnt": self.monster_dying_frame_cnt,
|
| 291 |
+
"walk1_mode": self.walk1_mode,
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
def __repr__(self):
|
| 295 |
+
return json.dumps(self.asdict())
|
multimodal/examples/mugen/data/coinrun/generate_text_desc.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Sequence:
|
| 11 |
+
def __init__(
|
| 12 |
+
self, start_frame, end_frame, pose_type, start_x, start_y, end_x, end_y
|
| 13 |
+
):
|
| 14 |
+
self.start_frame = start_frame
|
| 15 |
+
self.end_frame = end_frame
|
| 16 |
+
|
| 17 |
+
# 'ground' includes 'walk', 'duck', 'stand'; other types are 'climb', 'jump', 'hit'
|
| 18 |
+
self.pose_type = pose_type
|
| 19 |
+
self.start_x = start_x
|
| 20 |
+
self.start_y = start_y
|
| 21 |
+
self.end_x = end_x
|
| 22 |
+
self.end_y = end_y
|
| 23 |
+
self.time_jumps = 1 if pose_type == "jump" else 0
|
| 24 |
+
self.end_maze_above = "."
|
| 25 |
+
self.end_maze_below = "."
|
| 26 |
+
self.num_coins_eaten = 0
|
| 27 |
+
self.num_gems_eaten = 0
|
| 28 |
+
self.start_shield = False
|
| 29 |
+
self.end_shield = False
|
| 30 |
+
self.changed_shield = False
|
| 31 |
+
self.killed_monsters = []
|
| 32 |
+
self.jump_over_monsters = []
|
| 33 |
+
self.killed_by = ""
|
| 34 |
+
self.text_desc = ""
|
| 35 |
+
|
| 36 |
+
# Decide graduarity of text description (skip sequence shorter than this)
|
| 37 |
+
self.min_len_for_text_desc = 5
|
| 38 |
+
|
| 39 |
+
def asdict(self):
|
| 40 |
+
return {
|
| 41 |
+
"start_frame": self.start_frame,
|
| 42 |
+
"end_frame": self.end_frame,
|
| 43 |
+
"pose_type": self.pose_type,
|
| 44 |
+
"start_xy": (self.start_x, self.start_y),
|
| 45 |
+
"end_xy": (self.end_x, self.end_y),
|
| 46 |
+
"bumped_head": self.is_bumped_head(),
|
| 47 |
+
"same_level_jump": self.is_same_level_jump(),
|
| 48 |
+
"num_coins_eaten": self.num_coins_eaten,
|
| 49 |
+
"num_gems_eaten": self.num_gems_eaten,
|
| 50 |
+
"start_shield": self.start_shield,
|
| 51 |
+
"end_shield": self.end_shield,
|
| 52 |
+
"changed_shield": self.changed_shield,
|
| 53 |
+
"killed_monsters": self.killed_monsters,
|
| 54 |
+
"jump_over_monsters": self.jump_over_monsters,
|
| 55 |
+
"killed_by": self.killed_by,
|
| 56 |
+
"text_desc": self.text_desc,
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
def __repr__(self):
|
| 60 |
+
return json.dumps(self.asdict())
|
| 61 |
+
|
| 62 |
+
# bumped head will show as 'walk' pose and last for 1-2 frames
|
| 63 |
+
def is_bumped_head(self):
|
| 64 |
+
if (
|
| 65 |
+
self.pose_type == "ground"
|
| 66 |
+
and (self.end_frame - self.start_frame <= 1)
|
| 67 |
+
and self.end_maze_below in ".12"
|
| 68 |
+
): # and self.end_maze_above in 'SAab'
|
| 69 |
+
return True
|
| 70 |
+
return False
|
| 71 |
+
|
| 72 |
+
def is_same_level_jump(self):
|
| 73 |
+
if self.pose_type == "jump" and abs(self.end_y - self.start_y) <= 0.5:
|
| 74 |
+
return True
|
| 75 |
+
return False
|
| 76 |
+
|
| 77 |
+
def merge_sequences(self, sequences):
|
| 78 |
+
self.end_frame = sequences[-1].end_frame
|
| 79 |
+
self.end_x = sequences[-1].end_x
|
| 80 |
+
self.end_y = sequences[-1].end_y
|
| 81 |
+
self.end_maze_above = sequences[-1].end_maze_above
|
| 82 |
+
self.end_maze_below = sequences[-1].end_maze_below
|
| 83 |
+
for seq in sequences:
|
| 84 |
+
if seq.is_bumped_head():
|
| 85 |
+
self.time_jumps -= 1
|
| 86 |
+
self.time_jumps += seq.time_jumps
|
| 87 |
+
|
| 88 |
+
self.num_coins_eaten += seq.num_coins_eaten
|
| 89 |
+
self.num_gems_eaten += seq.num_gems_eaten
|
| 90 |
+
self.killed_monsters.extend(seq.killed_monsters)
|
| 91 |
+
self.jump_over_monsters.extend(seq.jump_over_monsters)
|
| 92 |
+
|
| 93 |
+
def process_metadata(self, game):
|
| 94 |
+
# generate game.flattened_monster_names if not already
|
| 95 |
+
# this is used to get monster names
|
| 96 |
+
if len(game.flattened_monster_names) == 0:
|
| 97 |
+
game.flatten_monster_names()
|
| 98 |
+
|
| 99 |
+
# count number of coins and gems eaten during the sequence
|
| 100 |
+
# start from one frame earlier (if not 0) so we can get change in the first frame
|
| 101 |
+
start_frame_id = max(self.start_frame - 1, 0)
|
| 102 |
+
if len(game.frames[self.end_frame].coins_eaten) > len(
|
| 103 |
+
game.frames[start_frame_id].coins_eaten
|
| 104 |
+
):
|
| 105 |
+
start_coin_set = {
|
| 106 |
+
(coord[0], coord[1])
|
| 107 |
+
for coord in game.frames[start_frame_id].coins_eaten
|
| 108 |
+
}
|
| 109 |
+
end_coin_set = {
|
| 110 |
+
(coord[0], coord[1])
|
| 111 |
+
for coord in game.frames[self.end_frame].coins_eaten
|
| 112 |
+
}
|
| 113 |
+
new_coins_eaten = end_coin_set - start_coin_set
|
| 114 |
+
for coin_coord in new_coins_eaten:
|
| 115 |
+
if game.maze[coin_coord[1]][coin_coord[0]] == "2":
|
| 116 |
+
self.num_gems_eaten += 1
|
| 117 |
+
else:
|
| 118 |
+
self.num_coins_eaten += 1
|
| 119 |
+
|
| 120 |
+
# check if Mugen changes between shield up and down mode during the sequence
|
| 121 |
+
self.start_shield = game.frames[self.start_frame].agent.power_up_mode
|
| 122 |
+
self.end_shield = game.frames[self.end_frame].agent.power_up_mode
|
| 123 |
+
shield_up_mode = False
|
| 124 |
+
shield_down_mode = False
|
| 125 |
+
for frame_id in range(self.start_frame, self.end_frame + 1):
|
| 126 |
+
if game.frames[frame_id].agent.power_up_mode:
|
| 127 |
+
shield_up_mode = True
|
| 128 |
+
else:
|
| 129 |
+
shield_down_mode = True
|
| 130 |
+
if shield_up_mode and shield_down_mode:
|
| 131 |
+
self.changed_shield = True
|
| 132 |
+
|
| 133 |
+
end_frame_id = min(self.end_frame + 2, len(game.frames))
|
| 134 |
+
for frame_id in range(self.start_frame, end_frame_id):
|
| 135 |
+
frame = game.frames[frame_id]
|
| 136 |
+
dead_monsters = set()
|
| 137 |
+
for i, m in enumerate(frame.monsters):
|
| 138 |
+
if m.is_dead:
|
| 139 |
+
dead_monsters.add(i)
|
| 140 |
+
# if more monsters are killed, record the monster killed and the frame id
|
| 141 |
+
if frame_id > self.start_frame and len(dead_monsters) > len(
|
| 142 |
+
prev_dead_monsters
|
| 143 |
+
):
|
| 144 |
+
killed_monster_theme = frame.monsters[
|
| 145 |
+
list(dead_monsters - prev_dead_monsters)[0]
|
| 146 |
+
].theme
|
| 147 |
+
self.killed_monsters.append(
|
| 148 |
+
game.flattened_monster_names[killed_monster_theme]
|
| 149 |
+
)
|
| 150 |
+
prev_dead_monsters = dead_monsters.copy()
|
| 151 |
+
|
| 152 |
+
# figure out which monster killed Mugen
|
| 153 |
+
killed_by_m_id = -1
|
| 154 |
+
if self.pose_type == "hit":
|
| 155 |
+
# check the monster distance in the first frame of hit sequence
|
| 156 |
+
m_min_dist = 1000 # just put some random large dist here
|
| 157 |
+
for m in game.frames[self.start_frame].monsters:
|
| 158 |
+
x_dist = self.start_x - m.x
|
| 159 |
+
y_dist = self.start_y - m.y
|
| 160 |
+
m_dist = x_dist * x_dist + y_dist * y_dist
|
| 161 |
+
if m_dist < m_min_dist:
|
| 162 |
+
killed_by_m_id = m.theme
|
| 163 |
+
m_min_dist = m_dist
|
| 164 |
+
if killed_by_m_id != -1:
|
| 165 |
+
self.killed_by = game.flattened_monster_names[killed_by_m_id]
|
| 166 |
+
|
| 167 |
+
# check for monsters jumped over
|
| 168 |
+
if self.pose_type == "jump":
|
| 169 |
+
# for purpose of checking jumped over monsters,
|
| 170 |
+
# ground y is fixed at the y coordinate of the previous frame
|
| 171 |
+
# note for jump sequence, start_y already recorded the location before jump starts
|
| 172 |
+
ground_y = round(self.start_y)
|
| 173 |
+
jump_over_monsters_set = set()
|
| 174 |
+
for frame_id in range(self.start_frame, self.end_frame + 1):
|
| 175 |
+
frame = game.frames[frame_id]
|
| 176 |
+
# this is the location below the agent at the same y level when jump starts
|
| 177 |
+
ground_loc = (round(frame.agent.x), ground_y)
|
| 178 |
+
for i, m in enumerate(frame.monsters):
|
| 179 |
+
if (round(m.x), round(m.y)) == ground_loc:
|
| 180 |
+
# use set to avoid adding duplicates
|
| 181 |
+
jump_over_monsters_set.add(i)
|
| 182 |
+
|
| 183 |
+
# now convert these into names, but only keep those that's still not killed by the next frame
|
| 184 |
+
for m_i in jump_over_monsters_set:
|
| 185 |
+
if not game.frames[end_frame_id - 1].monsters[m_i].is_dead:
|
| 186 |
+
self.jump_over_monsters.append(
|
| 187 |
+
game.flattened_monster_names[frame.monsters[m_i].theme]
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def generate_text_desc(self):
|
| 191 |
+
# only generate if sequence is long enough
|
| 192 |
+
if self.end_frame - self.start_frame < self.min_len_for_text_desc:
|
| 193 |
+
self.text_desc = ""
|
| 194 |
+
elif self.pose_type == "hit":
|
| 195 |
+
if self.killed_by != "":
|
| 196 |
+
self.text_desc = f"killed by a {self.killed_by}"
|
| 197 |
+
else:
|
| 198 |
+
self.text_desc = "killed by a monster"
|
| 199 |
+
else:
|
| 200 |
+
y_direct = ""
|
| 201 |
+
if self.end_y - self.start_y > 0.5:
|
| 202 |
+
y_direct = " up"
|
| 203 |
+
elif self.start_y - self.end_y > 0.5:
|
| 204 |
+
y_direct = " down"
|
| 205 |
+
else:
|
| 206 |
+
y_direct = " a bit" if self.pose_type == "ground" else ""
|
| 207 |
+
x_direct = ""
|
| 208 |
+
if self.end_x - self.start_x > 0.5:
|
| 209 |
+
x_direct = " to the right"
|
| 210 |
+
elif self.start_x - self.end_x > 0.5:
|
| 211 |
+
x_direct = " to the left"
|
| 212 |
+
else:
|
| 213 |
+
x_direct = " a bit" if self.pose_type == "ground" else ""
|
| 214 |
+
|
| 215 |
+
if self.pose_type == "climb":
|
| 216 |
+
self.text_desc = f"climbs{y_direct} on a ladder"
|
| 217 |
+
elif self.pose_type == "ground":
|
| 218 |
+
self.text_desc = f"walks{x_direct}" # TODO: add random verbs
|
| 219 |
+
elif self.pose_type == "jump":
|
| 220 |
+
jump_time_desc = ""
|
| 221 |
+
if self.time_jumps >= 2:
|
| 222 |
+
jump_time_desc = " a few times"
|
| 223 |
+
|
| 224 |
+
# only add jump destination if it's not a same level jump
|
| 225 |
+
jump_dest_desc = ""
|
| 226 |
+
if y_direct != "":
|
| 227 |
+
if self.end_maze_below in "SAab":
|
| 228 |
+
if self.end_y < 1.5:
|
| 229 |
+
jump_dest_desc = " to the ground"
|
| 230 |
+
else:
|
| 231 |
+
jump_dest_desc = " to a platform"
|
| 232 |
+
elif self.end_maze_below in "#$&%":
|
| 233 |
+
jump_dest_desc = " to a crate"
|
| 234 |
+
elif self.end_maze_below == "=":
|
| 235 |
+
jump_dest_desc = " to a ladder"
|
| 236 |
+
|
| 237 |
+
# add desc for monsters jumped over
|
| 238 |
+
jumped_over_desc = ""
|
| 239 |
+
if len(self.jump_over_monsters) > 0:
|
| 240 |
+
jumped_over_desc = " over a " + " and a ".join(
|
| 241 |
+
self.jump_over_monsters
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
self.text_desc = f"jumps{y_direct}{jump_time_desc}{x_direct}{jumped_over_desc}{jump_dest_desc}"
|
| 245 |
+
|
| 246 |
+
if self.num_coins_eaten > 0 or self.num_gems_eaten > 0:
|
| 247 |
+
self.text_desc += self.generate_collect_coin_desc()
|
| 248 |
+
|
| 249 |
+
if len(self.killed_monsters) > 0:
|
| 250 |
+
self.text_desc += " and killed a " + " and a ".join(
|
| 251 |
+
self.killed_monsters
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
def generate_collect_coin_desc(self):
|
| 255 |
+
if self.num_coins_eaten == 0 and self.num_gems_eaten == 0:
|
| 256 |
+
return ""
|
| 257 |
+
|
| 258 |
+
coin_descs = []
|
| 259 |
+
# add coin description if collected at least one coin
|
| 260 |
+
if self.num_coins_eaten == 1:
|
| 261 |
+
coin_descs.append(" a coin")
|
| 262 |
+
elif self.num_coins_eaten > 1:
|
| 263 |
+
coin_descs.append(" a few coins")
|
| 264 |
+
|
| 265 |
+
# add gem description if collected at least one gem
|
| 266 |
+
if self.num_gems_eaten == 1:
|
| 267 |
+
coin_descs.append(" a gem")
|
| 268 |
+
elif self.num_gems_eaten > 1:
|
| 269 |
+
coin_descs.append(" a few gems")
|
| 270 |
+
|
| 271 |
+
# connects descriptions for coins and gems with 'and'
|
| 272 |
+
coin_descs = " and".join(coin_descs)
|
| 273 |
+
|
| 274 |
+
# shield change should only be a result of eating gem or coin
|
| 275 |
+
if self.changed_shield:
|
| 276 |
+
coin_descs += self.generate_shield_desc()
|
| 277 |
+
|
| 278 |
+
return f" and collects{coin_descs}"
|
| 279 |
+
|
| 280 |
+
def generate_shield_desc(self):
|
| 281 |
+
if not self.start_shield and self.end_shield:
|
| 282 |
+
return " to turn on the shield"
|
| 283 |
+
elif self.start_shield and not self.end_shield:
|
| 284 |
+
return " to turn off the shield"
|
| 285 |
+
else:
|
| 286 |
+
# start and end in the same shield state but still changed shield during sequence
|
| 287 |
+
if self.start_shield:
|
| 288 |
+
return " to turn shield off then on again"
|
| 289 |
+
else:
|
| 290 |
+
return " to turn shield on then off again"
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def process_sequence(game, curr_pose_type, start_i, curr_i, last_seq=False):
|
| 294 |
+
# different type of pose, construct a sequence
|
| 295 |
+
# for 'jump', the start and end location is based on frame before the first and after the last frame
|
| 296 |
+
# for others, it's the first and last frame
|
| 297 |
+
if curr_pose_type == "jump":
|
| 298 |
+
pos_start_frame = max(start_i - 1, 0)
|
| 299 |
+
pos_end_frame = curr_i
|
| 300 |
+
else:
|
| 301 |
+
pos_start_frame = start_i
|
| 302 |
+
# curr_i will be one frame after, unless it's the last sequence of video
|
| 303 |
+
# however, for jump sequence, we do want one frame after to know where jump lands
|
| 304 |
+
pos_end_frame = curr_i - 1 if not last_seq else curr_i
|
| 305 |
+
|
| 306 |
+
seq = Sequence(
|
| 307 |
+
start_frame=start_i,
|
| 308 |
+
end_frame=curr_i - 1 if not last_seq else curr_i,
|
| 309 |
+
pose_type=curr_pose_type,
|
| 310 |
+
start_x=game.frames[pos_start_frame].agent.x,
|
| 311 |
+
start_y=game.frames[pos_start_frame].agent.y,
|
| 312 |
+
end_x=game.frames[pos_end_frame].agent.x,
|
| 313 |
+
end_y=game.frames[pos_end_frame].agent.y,
|
| 314 |
+
)
|
| 315 |
+
seq.end_maze_above = game.maze[round(seq.end_y) + 1][round(seq.end_x)]
|
| 316 |
+
seq.end_maze_below = game.maze[round(seq.end_y) - 1][round(seq.end_x)]
|
| 317 |
+
# sometimes jump may end a bit over the edge of cliff, this is to catch and fix that
|
| 318 |
+
if curr_pose_type == "jump" and seq.end_maze_below in ".12":
|
| 319 |
+
neighbor_x = (
|
| 320 |
+
int(seq.end_x) * 2 + 1 - round(seq.end_x)
|
| 321 |
+
) # get the opposite of round()
|
| 322 |
+
seq.end_maze_below = game.maze[round(seq.end_y) - 1][neighbor_x]
|
| 323 |
+
|
| 324 |
+
return seq
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def convert_game_to_text_desc(game, start_idx=0, end_idx=-1, alien_name="Mugen"):
|
| 328 |
+
if alien_name is None:
|
| 329 |
+
alien_name = "Mugen"
|
| 330 |
+
|
| 331 |
+
# if end_idx is not specified, set it to end of the game level
|
| 332 |
+
if end_idx == -1:
|
| 333 |
+
end_idx = len(game.frames)
|
| 334 |
+
start_idx = max(0, start_idx)
|
| 335 |
+
end_idx = min(len(game.frames), end_idx)
|
| 336 |
+
|
| 337 |
+
sequences = []
|
| 338 |
+
for i, f in enumerate(game.frames[start_idx:end_idx]):
|
| 339 |
+
pose = f.agent.pose.strip("12")
|
| 340 |
+
if pose in ["walk", "duck", "stand"]:
|
| 341 |
+
pose_type = "ground"
|
| 342 |
+
else:
|
| 343 |
+
pose_type = pose
|
| 344 |
+
if i == 0:
|
| 345 |
+
# first frame, initialize some status
|
| 346 |
+
start_i = 0
|
| 347 |
+
curr_pose_type = pose_type
|
| 348 |
+
continue
|
| 349 |
+
|
| 350 |
+
if pose_type == curr_pose_type:
|
| 351 |
+
# same type of pose, same sequence
|
| 352 |
+
continue
|
| 353 |
+
else:
|
| 354 |
+
seq = process_sequence(
|
| 355 |
+
game, curr_pose_type, start_idx + start_i, start_idx + i, last_seq=False
|
| 356 |
+
)
|
| 357 |
+
sequences.append(seq)
|
| 358 |
+
start_i = i
|
| 359 |
+
curr_pose_type = pose_type
|
| 360 |
+
|
| 361 |
+
# add the last leftover sequence
|
| 362 |
+
seq = process_sequence(
|
| 363 |
+
game, curr_pose_type, start_idx + start_i, start_idx + i, last_seq=True
|
| 364 |
+
)
|
| 365 |
+
sequences.append(seq)
|
| 366 |
+
|
| 367 |
+
# collapse two jumps into one sequence
|
| 368 |
+
# first pass, merge jumps before and after bumped head, this is to correctly identify jumps at the same level
|
| 369 |
+
seq_i = 0
|
| 370 |
+
reduced_sequences = []
|
| 371 |
+
while seq_i < len(sequences):
|
| 372 |
+
if seq_i == 0 or seq_i == len(sequences) - 1:
|
| 373 |
+
reduced_sequences.append(sequences[seq_i])
|
| 374 |
+
seq_i += 1
|
| 375 |
+
elif (
|
| 376 |
+
sequences[seq_i].is_bumped_head()
|
| 377 |
+
and reduced_sequences[-1].pose_type == "jump"
|
| 378 |
+
and sequences[seq_i + 1].pose_type == "jump"
|
| 379 |
+
):
|
| 380 |
+
# in case of bumped head, merge the jumps before and after
|
| 381 |
+
reduced_sequences[-1].merge_sequences(sequences[seq_i : seq_i + 2])
|
| 382 |
+
seq_i += 2
|
| 383 |
+
else:
|
| 384 |
+
reduced_sequences.append(sequences[seq_i])
|
| 385 |
+
seq_i += 1
|
| 386 |
+
sequences = reduced_sequences
|
| 387 |
+
|
| 388 |
+
# second pass, collapse two jumps into one sequence if they're both same level jumps
|
| 389 |
+
# jump up and down are not merged (unless it's separated by bumped head that will be merged in first pass)
|
| 390 |
+
result_sequences = []
|
| 391 |
+
seq_i = 0
|
| 392 |
+
max_ground_seq_len_to_merge = 5
|
| 393 |
+
while seq_i < len(sequences):
|
| 394 |
+
# only merge if it's a 'ground' sequence, and before/after are both jumps
|
| 395 |
+
if (
|
| 396 |
+
sequences[seq_i].pose_type != "ground"
|
| 397 |
+
or seq_i == 0
|
| 398 |
+
or seq_i == len(sequences) - 1
|
| 399 |
+
):
|
| 400 |
+
result_sequences.append(sequences[seq_i])
|
| 401 |
+
seq_i += 1
|
| 402 |
+
elif (
|
| 403 |
+
result_sequences[-1].pose_type != "jump"
|
| 404 |
+
or sequences[seq_i + 1].pose_type != "jump"
|
| 405 |
+
):
|
| 406 |
+
result_sequences.append(sequences[seq_i])
|
| 407 |
+
seq_i += 1
|
| 408 |
+
elif (
|
| 409 |
+
result_sequences[-1].is_same_level_jump()
|
| 410 |
+
and sequences[seq_i + 1].is_same_level_jump()
|
| 411 |
+
and (
|
| 412 |
+
sequences[seq_i].end_frame - sequences[seq_i].start_frame
|
| 413 |
+
< max_ground_seq_len_to_merge
|
| 414 |
+
)
|
| 415 |
+
):
|
| 416 |
+
# not bumped head, then only merge if sequence is short enough, and both jumps are the same level
|
| 417 |
+
result_sequences[-1].merge_sequences(sequences[seq_i : seq_i + 2])
|
| 418 |
+
seq_i += 2
|
| 419 |
+
else:
|
| 420 |
+
result_sequences.append(sequences[seq_i])
|
| 421 |
+
seq_i += 1
|
| 422 |
+
sequences = result_sequences
|
| 423 |
+
|
| 424 |
+
# generate text description for each sequence
|
| 425 |
+
text_descriptions = []
|
| 426 |
+
for seq in sequences:
|
| 427 |
+
seq.process_metadata(game)
|
| 428 |
+
seq.generate_text_desc()
|
| 429 |
+
if seq.text_desc != "":
|
| 430 |
+
text_descriptions.append(seq.text_desc)
|
| 431 |
+
|
| 432 |
+
# add Mugen in the beginning, then concat by 'and'
|
| 433 |
+
final_text_desc = alien_name + " " + ", and ".join(text_descriptions)
|
| 434 |
+
|
| 435 |
+
return final_text_desc
|
multimodal/examples/mugen/data/mugen_datamodules.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
|
| 9 |
+
import pytorch_lightning as pl
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
import torch.utils.data as data
|
| 13 |
+
|
| 14 |
+
from .mugen_dataset import MUGENDataset, MUGENDatasetArgs
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MUGENDataModule(pl.LightningDataModule):
|
| 18 |
+
"""General lightning data module for MUGEN dataset.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
mugen_dataset_args (MUGENDatasetArgs): arguments for MUGENDataset.
|
| 22 |
+
text_transform (Optional[Callable]): transform for text batches.
|
| 23 |
+
Only used when not ``None`` and when ``mugen_dataset_args.get_text_desc = True``.
|
| 24 |
+
Defaults to ``None``.
|
| 25 |
+
video_transform (Optional[Callable]): transform for video batches.
|
| 26 |
+
Only used when not ``None`` and when ``mugen_dataset_args.get_game_frame = True``.
|
| 27 |
+
Defaults to ``None``.
|
| 28 |
+
audio_transform (Optional[Callable]): transform for audio batches.
|
| 29 |
+
Only used when not ``None`` and when ``mugen_dataset_args.get_audio = True``.
|
| 30 |
+
Defaults to ``None``.
|
| 31 |
+
batch_size (int): number of samples per batch.
|
| 32 |
+
Defaults to ``16``.
|
| 33 |
+
num_workers (int): number of subprocesses for data loading.
|
| 34 |
+
Defaults to ``0``, meaning data is loaded in the main process.
|
| 35 |
+
shuffle (bool): whether to reshuffle data after each epoch.
|
| 36 |
+
Defaults to ``True``.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
mugen_dataset_args: MUGENDatasetArgs,
|
| 42 |
+
text_transform: Optional[Callable] = None,
|
| 43 |
+
video_transform: Optional[Callable] = None,
|
| 44 |
+
audio_transform: Optional[Callable] = None,
|
| 45 |
+
batch_size: int = 16,
|
| 46 |
+
num_workers: int = 0,
|
| 47 |
+
shuffle: bool = True,
|
| 48 |
+
):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.mugen_dataset_args = mugen_dataset_args
|
| 51 |
+
self.text_transform = text_transform
|
| 52 |
+
self.video_transform = video_transform
|
| 53 |
+
self.audio_transform = audio_transform
|
| 54 |
+
self.batch_size = batch_size
|
| 55 |
+
self.num_workers = num_workers
|
| 56 |
+
self.shuffle = shuffle
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def n_classes(self):
|
| 60 |
+
dataset = self._dataset(True)
|
| 61 |
+
return dataset.n_classes
|
| 62 |
+
|
| 63 |
+
def _custom_collate_fn(self, batch):
|
| 64 |
+
collated_batch = {}
|
| 65 |
+
if self.mugen_dataset_args.get_game_frame:
|
| 66 |
+
video = [elem["video"] for elem in batch]
|
| 67 |
+
video = torch.stack(video)
|
| 68 |
+
video = self.video_transform(video) if self.video_transform else video
|
| 69 |
+
collated_batch["video"] = video
|
| 70 |
+
if self.mugen_dataset_args.get_text_desc:
|
| 71 |
+
text = [elem["text"] for elem in batch]
|
| 72 |
+
# cannot be torch.stack'ed because still in raw text form, not Tensor
|
| 73 |
+
text = self.text_transform(text) if self.text_transform else text
|
| 74 |
+
collated_batch["text"] = text
|
| 75 |
+
if self.mugen_dataset_args.get_audio:
|
| 76 |
+
audio = [elem["audio"] for elem in batch]
|
| 77 |
+
audio = torch.stack(audio)
|
| 78 |
+
audio = self.audio_transform(audio) if self.audio_transform else audio
|
| 79 |
+
collated_batch["audio"] = audio
|
| 80 |
+
return collated_batch
|
| 81 |
+
|
| 82 |
+
def _dataset(self, split):
|
| 83 |
+
dataset = MUGENDataset(args=self.mugen_dataset_args, split=split)
|
| 84 |
+
return dataset
|
| 85 |
+
|
| 86 |
+
def _dataloader(self, split):
|
| 87 |
+
dataset = self._dataset(split)
|
| 88 |
+
if dist.is_initialized():
|
| 89 |
+
sampler = data.distributed.DistributedSampler(
|
| 90 |
+
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
sampler = None
|
| 94 |
+
dataloader = data.DataLoader(
|
| 95 |
+
dataset,
|
| 96 |
+
batch_size=self.batch_size,
|
| 97 |
+
num_workers=self.num_workers,
|
| 98 |
+
pin_memory=True,
|
| 99 |
+
sampler=sampler,
|
| 100 |
+
shuffle=sampler is None and self.shuffle is True,
|
| 101 |
+
collate_fn=self._custom_collate_fn,
|
| 102 |
+
)
|
| 103 |
+
return dataloader
|
| 104 |
+
|
| 105 |
+
def train_dataloader(self):
|
| 106 |
+
return self._dataloader("train")
|
| 107 |
+
|
| 108 |
+
def val_dataloader(self):
|
| 109 |
+
return self._dataloader("val")
|
| 110 |
+
|
| 111 |
+
def test_dataloader(self):
|
| 112 |
+
return self._dataloader("test")
|
multimodal/examples/mugen/generation/LoadAndComparePretrainedVQVAE.ipynb
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "ee3d68e4",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Compare MUGEN's Video VQVAE with TorchMultimodal's\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"This notebook loads the public MUGEN checkpoint for Video VQVAE, remaps the state_dict, and loads it into TorchMultimodal's Video VQVAE to ensure the outputs match. "
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "markdown",
|
| 15 |
+
"id": "5af9d001",
|
| 16 |
+
"metadata": {},
|
| 17 |
+
"source": [
|
| 18 |
+
"### Set directories\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"Replace these with your local directories."
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": 2,
|
| 26 |
+
"id": "071c8b48",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"checkpoint_dir = '/Users/rafiayub/checkpoints/'\n",
|
| 31 |
+
"repo_dir = '/Users/rafiayub/mugen/'\n",
|
| 32 |
+
"home_dir = '/Users/rafiayub/'"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "markdown",
|
| 37 |
+
"id": "a3a0f19f",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"source": [
|
| 40 |
+
"### Clone MUGEN's repo"
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "code",
|
| 45 |
+
"execution_count": null,
|
| 46 |
+
"id": "83812502",
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"outputs": [],
|
| 49 |
+
"source": [
|
| 50 |
+
"!git clone https://github.com/mugen-org/MUGEN_baseline.git $repo_dir"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "markdown",
|
| 55 |
+
"id": "07757cfa",
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"source": [
|
| 58 |
+
"### Download and unzip checkpoints\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"This will take some time."
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"cell_type": "code",
|
| 65 |
+
"execution_count": null,
|
| 66 |
+
"id": "d41a0c86",
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [],
|
| 69 |
+
"source": [
|
| 70 |
+
"!wget https://dl.noahmt.com/creativity/data/MUGEN_release/checkpoints.zip -P $checkpoint_dir"
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "code",
|
| 75 |
+
"execution_count": null,
|
| 76 |
+
"id": "01d9638a",
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"outputs": [],
|
| 79 |
+
"source": [
|
| 80 |
+
"import os\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"# Unzip checkpoints\n",
|
| 83 |
+
"zip_location = os.path.join(checkpoint_dir, 'checkpoints.zip')\n",
|
| 84 |
+
"!unzip $zip_location -d $checkpoint_dir"
|
| 85 |
+
]
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"cell_type": "markdown",
|
| 89 |
+
"id": "f06c8938",
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"source": [
|
| 92 |
+
"### Load checkpoint into MUGEN model"
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"cell_type": "code",
|
| 97 |
+
"execution_count": 3,
|
| 98 |
+
"id": "f3e74b3a",
|
| 99 |
+
"metadata": {},
|
| 100 |
+
"outputs": [],
|
| 101 |
+
"source": [
|
| 102 |
+
"import sys\n",
|
| 103 |
+
"import os\n",
|
| 104 |
+
"sys.path.append(home_dir)\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"import torch\n",
|
| 107 |
+
"from torch import nn\n",
|
| 108 |
+
"import mugen\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"ckpt = torch.load(\n",
|
| 111 |
+
" os.path.join(checkpoint_dir, 'generation/video_vqvae/L32/epoch=54-step=599999.ckpt'), \n",
|
| 112 |
+
" map_location=torch.device('cpu')\n",
|
| 113 |
+
")"
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "markdown",
|
| 118 |
+
"id": "3ea6d13e",
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"source": [
|
| 121 |
+
"The arguments are taken from MUGEN's training scripts found at: https://github.com/mugen-org/MUGEN_baseline/blob/main/generation/experiments/vqvae/VideoVQVAE_L32.sh"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": 4,
|
| 127 |
+
"id": "f81bea2e",
|
| 128 |
+
"metadata": {},
|
| 129 |
+
"outputs": [],
|
| 130 |
+
"source": [
|
| 131 |
+
"class Namespace:\n",
|
| 132 |
+
" def __init__(self, **kwargs):\n",
|
| 133 |
+
" self.__dict__.update(kwargs)\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"vqvae_args=Namespace(\n",
|
| 137 |
+
" embedding_dim=256,\n",
|
| 138 |
+
" n_codes=2048,\n",
|
| 139 |
+
" n_hiddens=240,\n",
|
| 140 |
+
" n_res_layers=4,\n",
|
| 141 |
+
" lr=0.0003,\n",
|
| 142 |
+
" downsample=(4, 32, 32),\n",
|
| 143 |
+
" kernel_size=3,\n",
|
| 144 |
+
" sequence_length=16,\n",
|
| 145 |
+
" resolution=256,\n",
|
| 146 |
+
")\n",
|
| 147 |
+
"vv_mugen = mugen.VQVAE(vqvae_args)"
|
| 148 |
+
]
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"cell_type": "code",
|
| 152 |
+
"execution_count": 5,
|
| 153 |
+
"id": "fbdcf1f6",
|
| 154 |
+
"metadata": {},
|
| 155 |
+
"outputs": [
|
| 156 |
+
{
|
| 157 |
+
"data": {
|
| 158 |
+
"text/plain": [
|
| 159 |
+
"<All keys matched successfully>"
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
"execution_count": 5,
|
| 163 |
+
"metadata": {},
|
| 164 |
+
"output_type": "execute_result"
|
| 165 |
+
}
|
| 166 |
+
],
|
| 167 |
+
"source": [
|
| 168 |
+
"vv_mugen.load_state_dict(ckpt['state_dict'])"
|
| 169 |
+
]
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"cell_type": "markdown",
|
| 173 |
+
"id": "a6bfb325",
|
| 174 |
+
"metadata": {},
|
| 175 |
+
"source": [
|
| 176 |
+
"### Create TorchMultimodal's Video VQVAE"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"cell_type": "code",
|
| 181 |
+
"execution_count": 6,
|
| 182 |
+
"id": "74e6bd54",
|
| 183 |
+
"metadata": {},
|
| 184 |
+
"outputs": [],
|
| 185 |
+
"source": [
|
| 186 |
+
"from examples.mugen.generation.video_vqvae import video_vqvae_mugen\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"vv_torchmm = video_vqvae_mugen(pretrained_model_key=None)"
|
| 189 |
+
]
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"cell_type": "markdown",
|
| 193 |
+
"id": "e612d831",
|
| 194 |
+
"metadata": {},
|
| 195 |
+
"source": [
|
| 196 |
+
"### Remap MUGEN's state_dict and load into new model"
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"cell_type": "code",
|
| 201 |
+
"execution_count": 7,
|
| 202 |
+
"id": "5f4d4774",
|
| 203 |
+
"metadata": {},
|
| 204 |
+
"outputs": [],
|
| 205 |
+
"source": [
|
| 206 |
+
"import re\n",
|
| 207 |
+
"\n",
|
| 208 |
+
"def map_state_dict(state_dict):\n",
|
| 209 |
+
" mapped_state_dict = {}\n",
|
| 210 |
+
" dim_map = {'w': '2', 'h': '1', 't': '0'}\n",
|
| 211 |
+
" layer_map = {'w_qs': 'query', 'w_ks': 'key', 'w_vs': 'value', 'fc': 'output'}\n",
|
| 212 |
+
" for param, val in state_dict.items():\n",
|
| 213 |
+
" new_param = param\n",
|
| 214 |
+
" res = re.search('encoder.convs.', param)\n",
|
| 215 |
+
" if res:\n",
|
| 216 |
+
" idx = res.end()\n",
|
| 217 |
+
" layer_id = int(param[idx])\n",
|
| 218 |
+
" new_param = param[:idx] + str(layer_id * 2) + param[idx+1:]\n",
|
| 219 |
+
" mapped_state_dict[new_param] = val\n",
|
| 220 |
+
" continue\n",
|
| 221 |
+
" res = re.search('encoder.conv_last', param)\n",
|
| 222 |
+
" if res:\n",
|
| 223 |
+
" idx = res.start() + len('encoder.')\n",
|
| 224 |
+
" new_param = param[:idx] + 'convs.10' + param[res.end():]\n",
|
| 225 |
+
" mapped_state_dict[new_param] = val\n",
|
| 226 |
+
" continue\n",
|
| 227 |
+
" res = re.search('attn_[w,h,t]\\..*\\.', param)\n",
|
| 228 |
+
" if res:\n",
|
| 229 |
+
" dim = param[res.start()+5]\n",
|
| 230 |
+
" new_dim = dim_map[dim]\n",
|
| 231 |
+
" layer = param[res.start()+7:res.end()-1]\n",
|
| 232 |
+
" new_layer = layer_map[layer]\n",
|
| 233 |
+
" new_param = param[:res.start()] + 'mha_attns.' + new_dim + '.' + new_layer + '.' + param[res.end():]\n",
|
| 234 |
+
" mapped_state_dict[new_param] = val\n",
|
| 235 |
+
" continue\n",
|
| 236 |
+
" res = re.search('pre_vq_conv', param)\n",
|
| 237 |
+
" if res:\n",
|
| 238 |
+
" new_param = 'encoder.conv_out' + param[res.end():]\n",
|
| 239 |
+
" mapped_state_dict[new_param] = val\n",
|
| 240 |
+
" continue\n",
|
| 241 |
+
" res = re.search('post_vq_conv', param)\n",
|
| 242 |
+
" if res:\n",
|
| 243 |
+
" new_param = 'decoder.conv_in' + param[res.end():]\n",
|
| 244 |
+
" mapped_state_dict[new_param] = val\n",
|
| 245 |
+
" continue\n",
|
| 246 |
+
" res = re.search('decoder.convts.', param)\n",
|
| 247 |
+
" if res:\n",
|
| 248 |
+
" idx = res.end()\n",
|
| 249 |
+
" layer_id = int(param[idx])\n",
|
| 250 |
+
" new_param = param[:idx] + str(layer_id * 2) + param[idx+1:]\n",
|
| 251 |
+
" mapped_state_dict[new_param] = val\n",
|
| 252 |
+
" continue\n",
|
| 253 |
+
" if param == 'codebook.N':\n",
|
| 254 |
+
" new_param = 'codebook.code_usage'\n",
|
| 255 |
+
" mapped_state_dict[new_param] = val\n",
|
| 256 |
+
" continue\n",
|
| 257 |
+
" if param == 'codebook.z_avg':\n",
|
| 258 |
+
" new_param = 'codebook.code_avg'\n",
|
| 259 |
+
" mapped_state_dict[new_param] = val\n",
|
| 260 |
+
" continue\n",
|
| 261 |
+
" if param == 'codebook.embeddings':\n",
|
| 262 |
+
" new_param = 'codebook.embedding'\n",
|
| 263 |
+
" mapped_state_dict[new_param] = val\n",
|
| 264 |
+
" continue\n",
|
| 265 |
+
" \n",
|
| 266 |
+
" mapped_state_dict[new_param] = val\n",
|
| 267 |
+
" \n",
|
| 268 |
+
" return mapped_state_dict"
|
| 269 |
+
]
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"cell_type": "code",
|
| 273 |
+
"execution_count": 8,
|
| 274 |
+
"id": "38234858",
|
| 275 |
+
"metadata": {},
|
| 276 |
+
"outputs": [],
|
| 277 |
+
"source": [
|
| 278 |
+
"new_state_dict = map_state_dict(ckpt['state_dict'])"
|
| 279 |
+
]
|
| 280 |
+
},
|
| 281 |
+
{
|
| 282 |
+
"cell_type": "code",
|
| 283 |
+
"execution_count": 9,
|
| 284 |
+
"id": "e160fb51",
|
| 285 |
+
"metadata": {
|
| 286 |
+
"scrolled": false
|
| 287 |
+
},
|
| 288 |
+
"outputs": [
|
| 289 |
+
{
|
| 290 |
+
"data": {
|
| 291 |
+
"text/plain": [
|
| 292 |
+
"<All keys matched successfully>"
|
| 293 |
+
]
|
| 294 |
+
},
|
| 295 |
+
"execution_count": 9,
|
| 296 |
+
"metadata": {},
|
| 297 |
+
"output_type": "execute_result"
|
| 298 |
+
}
|
| 299 |
+
],
|
| 300 |
+
"source": [
|
| 301 |
+
"vv_torchmm.load_state_dict(new_state_dict)"
|
| 302 |
+
]
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"cell_type": "markdown",
|
| 306 |
+
"id": "46d58eb7",
|
| 307 |
+
"metadata": {},
|
| 308 |
+
"source": [
|
| 309 |
+
"### Compare outputs with a random input"
|
| 310 |
+
]
|
| 311 |
+
},
|
| 312 |
+
{
|
| 313 |
+
"cell_type": "code",
|
| 314 |
+
"execution_count": 10,
|
| 315 |
+
"id": "3c85cdd3",
|
| 316 |
+
"metadata": {},
|
| 317 |
+
"outputs": [
|
| 318 |
+
{
|
| 319 |
+
"name": "stdout",
|
| 320 |
+
"output_type": "stream",
|
| 321 |
+
"text": [
|
| 322 |
+
"Max difference between outputs: 3.0875205993652344e-05\n",
|
| 323 |
+
"Mean difference between outputs: 1.7353995929170196e-07\n"
|
| 324 |
+
]
|
| 325 |
+
}
|
| 326 |
+
],
|
| 327 |
+
"source": [
|
| 328 |
+
"torch.manual_seed(4)\n",
|
| 329 |
+
"video = torch.randn(1,3,32,256,256) # b, c, t, h, w\n",
|
| 330 |
+
"\n",
|
| 331 |
+
"vv_mugen.eval()\n",
|
| 332 |
+
"vv_torchmm.eval()\n",
|
| 333 |
+
"\n",
|
| 334 |
+
"loss, x_recon, codebook_output = vv_mugen(video)\n",
|
| 335 |
+
"output = vv_torchmm(video)\n",
|
| 336 |
+
"\n",
|
| 337 |
+
"diff = abs(output.decoded - x_recon)\n",
|
| 338 |
+
"print(f'Max difference between outputs: {torch.max(diff).item()}')\n",
|
| 339 |
+
"print(f'Mean difference between outputs: {torch.mean(diff).item()}')"
|
| 340 |
+
]
|
| 341 |
+
},
|
| 342 |
+
{
|
| 343 |
+
"cell_type": "markdown",
|
| 344 |
+
"id": "fa78569e",
|
| 345 |
+
"metadata": {},
|
| 346 |
+
"source": [
|
| 347 |
+
"### Save mapped checkpoint"
|
| 348 |
+
]
|
| 349 |
+
},
|
| 350 |
+
{
|
| 351 |
+
"cell_type": "code",
|
| 352 |
+
"execution_count": 9,
|
| 353 |
+
"id": "48651d44",
|
| 354 |
+
"metadata": {},
|
| 355 |
+
"outputs": [],
|
| 356 |
+
"source": [
|
| 357 |
+
"save_path = '/Users/rafiayub/checkpoints/generation/video_vqvae/mugen_video_vqvae_L32.pt'\n",
|
| 358 |
+
"torch.save(new_state_dict, save_path)"
|
| 359 |
+
]
|
| 360 |
+
}
|
| 361 |
+
],
|
| 362 |
+
"metadata": {
|
| 363 |
+
"kernelspec": {
|
| 364 |
+
"display_name": "Python 3 (ipykernel)",
|
| 365 |
+
"language": "python",
|
| 366 |
+
"name": "python3"
|
| 367 |
+
},
|
| 368 |
+
"language_info": {
|
| 369 |
+
"codemirror_mode": {
|
| 370 |
+
"name": "ipython",
|
| 371 |
+
"version": 3
|
| 372 |
+
},
|
| 373 |
+
"file_extension": ".py",
|
| 374 |
+
"mimetype": "text/x-python",
|
| 375 |
+
"name": "python",
|
| 376 |
+
"nbconvert_exporter": "python",
|
| 377 |
+
"pygments_lexer": "ipython3",
|
| 378 |
+
"version": "3.9.12"
|
| 379 |
+
}
|
| 380 |
+
},
|
| 381 |
+
"nbformat": 4,
|
| 382 |
+
"nbformat_minor": 5
|
| 383 |
+
}
|
multimodal/examples/mugen/generation/README.md
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Text-to-Video Generation with MUGEN
|
| 2 |
+
|
| 3 |
+
This directory contains the high-level model components for text-to-video generation following [MUGEN](https://arxiv.org/abs/2204.08058). They demonstrate how to use building blocks from TorchMultimodal to quickly assemble a new auto-regressive generative model for different pairs of modalities. Here is a [colab demo](https://colab.research.google.com/drive/1C3ZbH_l19g_KqW3CPeX2-8Q2sOUCpmZo?usp=sharing) showing how to generate a video clip from text prompts.
|
| 4 |
+
|
| 5 |
+
https://user-images.githubusercontent.com/23155714/196074330-6f03593c-da8e-473f-8935-8bf1950baa33.mp4
|
| 6 |
+
|
| 7 |
+
```python
|
| 8 |
+
from torchmultimodal.utils.generate import GenerationUtil
|
| 9 |
+
from examples.mugen.generation.text_video_gpt import text_video_gpt
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
model = text_video_gpt(video_seq_len=32, pretrained_text_video_gpt_model_key="mugen_L32")
|
| 13 |
+
generator = GenerationUtil(model)
|
| 14 |
+
|
| 15 |
+
output = generator.sample(
|
| 16 |
+
['Mugen moves left to right on a cliff and picks up a gem.'],
|
| 17 |
+
max_seq_len=512,
|
| 18 |
+
use_cache=True,
|
| 19 |
+
causal=True,
|
| 20 |
+
device=<current_device>,
|
| 21 |
+
)
|
| 22 |
+
samples = output.decoded
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## Model
|
| 26 |
+
The model architecture used by MUGEN follows [DALL-E](https://arxiv.org/abs/2102.12092) but with the image components replaced by those for video following [VideoGPT](https://arxiv.org/abs/2104.10157).
|
| 27 |
+
|
| 28 |
+
Multimodal generation involves generation of samples in one modality given inputs from another. As in the text-to-image generation model DALL-E, it typically involves a two-stage process of first learning a discrete latent representation for each modality and then using a [GPT](https://openai.com/blog/language-unsupervised/) transformer decoder to learn a joint prior for both modalities in the latent space. For text data, the latent representation is obtained through tokenization such as [BPE](https://en.wikipedia.org/wiki/Byte_pair_encoding) used in this example. For high dimensional data such as video and image, a [VQ-VAE](https://arxiv.org/abs/1711.00937) model is used to learn a set of downsampled discrete embedding vectors through nearest-neighbor lookups from a "codebook" where the chosen indices are referred to as the token ids following convention from language modeling.
|
| 29 |
+
|
| 30 |
+
VideoGPT is a generative model for video using a VQ-VAE model with video encoder/decoder and a GPT transformer decoder for token generation. The encoder and the decoder use 3D-convolution and self axial-attention to learn video information.
|
| 31 |
+
|
| 32 |
+
## Generation
|
| 33 |
+
In this example generation refers to the auto-regressive process where we iteratively predict the next token id from the current until reaching the desired output length, a technique initially used by language modeling but has been extended to multimodal generation. To control the generation process, a top level abstraction is provided as a utility in [generate.py](https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/utils/generate.py) which takes the model as an input.
|
multimodal/examples/mugen/generation/text_video_gpt.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from examples.mugen.generation.video_vqvae import video_vqvae_mugen
|
| 12 |
+
|
| 13 |
+
from torch import nn, Tensor
|
| 14 |
+
|
| 15 |
+
from torchmultimodal.models.video_gpt.gpt import (
|
| 16 |
+
MultimodalGPT,
|
| 17 |
+
MultimodalTransformerDecoder,
|
| 18 |
+
RightShift,
|
| 19 |
+
TransformerDecoder,
|
| 20 |
+
TransformerDecoderLayer,
|
| 21 |
+
)
|
| 22 |
+
from torchmultimodal.modules.layers.attention import SelfAttention
|
| 23 |
+
from torchmultimodal.modules.layers.position_embedding import (
|
| 24 |
+
BroadcastedPositionEmbedding,
|
| 25 |
+
)
|
| 26 |
+
from torchmultimodal.utils.common import load_module_from_url
|
| 27 |
+
from torchtext.transforms import CharBPETokenizer
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
PRETRAINED_TOKENIZER_ENCODER_URL = "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/tokenizer-coinrun_1024_encoder.json"
|
| 31 |
+
PRETRAINED_TOKENIZER_MERGES_URL = "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/tokenizer-coinrun_1024_merges.txt"
|
| 32 |
+
PRETRAINED_TEXT_VIDEO_GPT_URL_MAPPING = {
|
| 33 |
+
"mugen_L32": "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/text_video_gpt_L32_weights-17db9549.pth",
|
| 34 |
+
"mugen_L16": "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/text_video_gpt_L16_weights-5dfc5a0a.pth",
|
| 35 |
+
"mugen_L8": "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/text_video_gpt_L8_weights-72b6d2ab.pth",
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def text_video_gpt(
|
| 40 |
+
text_seq_len: int = 128,
|
| 41 |
+
video_seq_len: int = 32,
|
| 42 |
+
resolution: int = 256,
|
| 43 |
+
downsample: Tuple[int, int, int] = (4, 32, 32),
|
| 44 |
+
d_model: int = 768,
|
| 45 |
+
n_head: int = 8,
|
| 46 |
+
dropout: float = 0.2,
|
| 47 |
+
attn_dropout: float = 0.3,
|
| 48 |
+
num_decoder_layers: int = 12,
|
| 49 |
+
use_gpt_init: bool = True,
|
| 50 |
+
pretrained_text_tokenizer_encoder_url: str = PRETRAINED_TOKENIZER_ENCODER_URL,
|
| 51 |
+
pretrained_text_tokenizer_merges_url: str = PRETRAINED_TOKENIZER_MERGES_URL,
|
| 52 |
+
pretrained_video_vqvae_model_key: Optional[str] = None,
|
| 53 |
+
pretrained_text_video_gpt_model_key: Optional[str] = None,
|
| 54 |
+
) -> MultimodalGPT:
|
| 55 |
+
"""Builds a text-to-video GPT model from user inputs
|
| 56 |
+
|
| 57 |
+
Parameter defaults follow MUGEN project:
|
| 58 |
+
* Video VQVAE: https://github.com/mugen-org/MUGEN_baseline/tree/main/generation/experiments/vqvae
|
| 59 |
+
* GPT: https://github.com/mugen-org/MUGEN_baseline/blob/main/lib/models/gpt/gpt.py#L252
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
text_seq_len (int): Length of text sequences after padding. Defaults to ``128``.
|
| 63 |
+
video_seq_len (int): Length of video sequences sampled from the dataset. Defaults to ``32``. Other
|
| 64 |
+
values used by MUGEN are ``8``, ``16``.
|
| 65 |
+
resolution (int): Resolution of the sampled video sequences defining height and width of each frame.
|
| 66 |
+
Defaults to ``256``.
|
| 67 |
+
downsample (Tuple[int, int, int]): Ratio by which to disperse along each dimension the sampled sequences.
|
| 68 |
+
For example, if the original frame is ``(32, 256, 256)``, after downsampling by ``(4, 32, 32)`` the
|
| 69 |
+
new frame will be of shape ``(8, 8, 8)`` with each dim divided by the rate of downsample. Defaults to
|
| 70 |
+
``(4, 32, 32)``.
|
| 71 |
+
d_model (int): Dimension of the underlying transformer decoder.
|
| 72 |
+
See :py:class:`torchmultimodal.models.video_gpt.gpt.TransformerDecoderLayer`. Defaults to ``768``.
|
| 73 |
+
n_head (int): Number of attention heads used by the transformer decoder. Defaults to ``8``.
|
| 74 |
+
dropout (float): Dropout probability used by the projection layer of the transformer decoder.
|
| 75 |
+
Defaults to ``0.2``.
|
| 76 |
+
attn_dropout (float): Dropout probability used by the attention layer of the transformer decoder.
|
| 77 |
+
Defaults to ``0.3``.
|
| 78 |
+
num_decoder_layers (int): Number of transformer decoder layers. Defaults to ``12``.
|
| 79 |
+
use_gpt_init (bool): Whether uses parameter initialization of GPT model. Defaults to ``True``.
|
| 80 |
+
pretrained_text_tokenizer_encoder_url (str): Remote location of the pretrained text tokenizer encoder file.
|
| 81 |
+
Defaults to `"MUGEN pretrained tokenizer encoder file
|
| 82 |
+
"<https://pytorch.s3.amazonaws.com/models/multimodal/mugen/tokenizer-coinrun_1024_encoder.json>`_.
|
| 83 |
+
pretrained_text_tokenizer_merges_url (str): Remote location of the pretrained text tokenizer merges file.
|
| 84 |
+
Defaults to `"MUGEN pretrained tokenizer merges file
|
| 85 |
+
"<https://pytorch.s3.amazonaws.com/models/multimodal/mugen/tokenizer-coinrun_1024_merges.txt>`_.
|
| 86 |
+
pretrained_video_vqvae_model_key (str, optional): Key to select the pretrained MUGEN VideoVQVAE weights
|
| 87 |
+
file. For allowed values, see :py:module:`examples/mugen/generation/video_vqvae.py`.
|
| 88 |
+
Defaults to ``None``.
|
| 89 |
+
pretrained_text_video_gpt_model_key (str, optional): Key to select the pretrained MUGEN TextVideoGPT
|
| 90 |
+
weights file. The provided key should match that of MUGEN VideoVQVAE to ensure the two models were
|
| 91 |
+
pretrained for the same video sequence length. For example ``L32`` means the video sequence length
|
| 92 |
+
is ``32``. The loaded weights will override those from the frozen VideoVQVAE model.
|
| 93 |
+
Defaults to ``None``.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
An instance of :py:class:`torchmultimodal.models.video_gpt.gpt.MultimodalGPT`.
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
# builds text tokenizer from pre-trained
|
| 100 |
+
tokenizer = CharBPETokenizer(
|
| 101 |
+
bpe_encoder_path=pretrained_text_tokenizer_encoder_url,
|
| 102 |
+
bpe_merges_path=pretrained_text_tokenizer_merges_url,
|
| 103 |
+
unk_token="[UNK]",
|
| 104 |
+
special_tokens=["[PAD]", "[CLS]", "[SEP]", "[UNK]", "[MASK]"],
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# builds text tokenizer
|
| 108 |
+
text_tokenizer = TextTokenizer(
|
| 109 |
+
context_len=text_seq_len,
|
| 110 |
+
d_model=d_model,
|
| 111 |
+
tokenizer=tokenizer,
|
| 112 |
+
)
|
| 113 |
+
num_text_tokens = text_tokenizer.num_text_tokens
|
| 114 |
+
|
| 115 |
+
# builds video tokenizer
|
| 116 |
+
video_vqvae = video_vqvae_mugen(
|
| 117 |
+
pretrained_model_key=pretrained_video_vqvae_model_key,
|
| 118 |
+
freeze_model=True,
|
| 119 |
+
)
|
| 120 |
+
video_vqvae.eval()
|
| 121 |
+
num_video_tokens = video_vqvae.num_embeddings # size of the codebook
|
| 122 |
+
|
| 123 |
+
# derives the expected latent shape from video input shape
|
| 124 |
+
video_input_shape = (video_seq_len, resolution, resolution)
|
| 125 |
+
video_latent_shape = latent_shape(video_input_shape, downsample)
|
| 126 |
+
video_vqvae_latent_shape = video_vqvae.latent_shape(video_input_shape)
|
| 127 |
+
# video vqvae will apply convolutions to the input shape which effectively
|
| 128 |
+
# reduces the size by ``dim//stride`` after each layer
|
| 129 |
+
# sanity check that the expected and actual latent shapes are consistent
|
| 130 |
+
if video_latent_shape != video_vqvae_latent_shape:
|
| 131 |
+
raise ValueError(
|
| 132 |
+
f"Latent shape derived from video inputs: {video_latent_shape} "
|
| 133 |
+
f"does not match that of video vqvae: {video_vqvae_latent_shape}"
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# builds text embedding projection: text_emb is already of output shape `d_model`
|
| 137 |
+
# generally a projection layer is needed to bridge the tokenizer and
|
| 138 |
+
# `torchmultimodal.models.gpt.MultimodalTransformerDecoder`, see `video_projection`
|
| 139 |
+
text_projection = nn.Identity()
|
| 140 |
+
|
| 141 |
+
# builds video embedding projection
|
| 142 |
+
video_projection = nn.Linear(video_vqvae.embedding_dim, d_model, bias=False)
|
| 143 |
+
|
| 144 |
+
# builds multimodal decoder
|
| 145 |
+
text_pos_emb = nn.Embedding(text_seq_len, d_model)
|
| 146 |
+
video_pos_emb = BroadcastedPositionEmbedding(video_latent_shape, d_model)
|
| 147 |
+
attention_layer = SelfAttention(attn_dropout=attn_dropout)
|
| 148 |
+
decoder_layer = TransformerDecoderLayer(
|
| 149 |
+
d_model, n_head, dropout, attn_module=attention_layer
|
| 150 |
+
)
|
| 151 |
+
decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
|
| 152 |
+
right_shift = RightShift(d_model)
|
| 153 |
+
mm_decoder = MultimodalTransformerDecoder(
|
| 154 |
+
text_pos_emb, video_pos_emb, decoder, right_shift
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
model = MultimodalGPT(
|
| 158 |
+
d_model=d_model,
|
| 159 |
+
num_in_tokens=num_text_tokens,
|
| 160 |
+
num_out_tokens=num_video_tokens,
|
| 161 |
+
latent_shape=video_latent_shape,
|
| 162 |
+
in_tokenizer=text_tokenizer,
|
| 163 |
+
out_tokenizer=video_vqvae,
|
| 164 |
+
mm_decoder=mm_decoder,
|
| 165 |
+
in_projection=text_projection,
|
| 166 |
+
out_projection=video_projection,
|
| 167 |
+
use_gpt_init=use_gpt_init,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
if pretrained_text_video_gpt_model_key is not None:
|
| 171 |
+
if (
|
| 172 |
+
pretrained_text_video_gpt_model_key
|
| 173 |
+
not in PRETRAINED_TEXT_VIDEO_GPT_URL_MAPPING
|
| 174 |
+
):
|
| 175 |
+
raise KeyError(
|
| 176 |
+
f"Invalid pretrained model key: {pretrained_text_video_gpt_model_key}"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
load_module_from_url(
|
| 180 |
+
model,
|
| 181 |
+
PRETRAINED_TEXT_VIDEO_GPT_URL_MAPPING[pretrained_text_video_gpt_model_key],
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
return model
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def latent_shape(
|
| 188 |
+
input_shape: Tuple[int, ...], downsample: Tuple[int, ...]
|
| 189 |
+
) -> Tuple[int, ...]:
|
| 190 |
+
"""Derives latent shape of video inputs after VQ-VAE encoding"""
|
| 191 |
+
return tuple([s // d for s, d in zip(input_shape, downsample)])
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class TextTokenizer(nn.Module):
|
| 195 |
+
"""Converts between text and tokens / embedings
|
| 196 |
+
|
| 197 |
+
Wrapper around the tokenizer to be consistent with the API required by
|
| 198 |
+
:py:class:`torchmultimodal.models.video_gpt.gpt.MultimodalGPT`. It also contains the
|
| 199 |
+
embedding layer to enable lookup by token ids.
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
def __init__(
|
| 203 |
+
self,
|
| 204 |
+
context_len: int,
|
| 205 |
+
d_model: int,
|
| 206 |
+
tokenizer: nn.Module,
|
| 207 |
+
) -> None:
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.tokenizer = tokenizer
|
| 210 |
+
self.pad_id = self.tokenizer.encode("[PAD]")[0] # type: ignore
|
| 211 |
+
self.vocab_size = self.tokenizer.vocab_size # type: ignore
|
| 212 |
+
self.context_len = context_len
|
| 213 |
+
# MUGEN treats padding as unique ids so adding them to the total text tokens
|
| 214 |
+
# https://github.com/mugen-org/MUGEN_baseline/blob/main/lib/models/gpt/gpt.py#L44
|
| 215 |
+
self.num_text_tokens = self.vocab_size + context_len
|
| 216 |
+
self.embedding = nn.Embedding(self.num_text_tokens, d_model)
|
| 217 |
+
|
| 218 |
+
def text_to_tokens(self, sentences: List[str]) -> Tensor:
|
| 219 |
+
"""Pads the sentences to be of equal lengths"""
|
| 220 |
+
tokens = [
|
| 221 |
+
self.tokenizer.encode(sentence.strip().lower() + " [SEP]") # type: ignore
|
| 222 |
+
for sentence in sentences
|
| 223 |
+
]
|
| 224 |
+
token_ids = [t[: self.context_len] for t in tokens]
|
| 225 |
+
# pad each sentence to be of length `context_len`
|
| 226 |
+
for i, t in enumerate(token_ids):
|
| 227 |
+
t += [self.pad_id] * (self.context_len - len(t))
|
| 228 |
+
token_ids[i] = t
|
| 229 |
+
|
| 230 |
+
return torch.Tensor(token_ids).type(torch.int64)
|
| 231 |
+
|
| 232 |
+
def encode(self, sentences: List[str], device: str) -> Tensor:
|
| 233 |
+
"""Encodes sentences to token ids"""
|
| 234 |
+
token_ids = self.text_to_tokens(sentences).to(device)
|
| 235 |
+
# bump padding token ids by vocab_size so that they do not coincide with un-padded token ids
|
| 236 |
+
# and that the padding token ids themselves are unique
|
| 237 |
+
unique_pad_ids = torch.arange(self.context_len, device=device) + self.vocab_size
|
| 238 |
+
token_ids = torch.where(token_ids == self.pad_id, unique_pad_ids, token_ids)
|
| 239 |
+
return token_ids
|
| 240 |
+
|
| 241 |
+
def _filter_token_ids(self, token_ids: List[int]) -> List[Optional[int]]:
|
| 242 |
+
"""Filters out token ids out side of vocab"""
|
| 243 |
+
return [
|
| 244 |
+
token_id
|
| 245 |
+
for token_id in token_ids
|
| 246 |
+
if token_id > 0 and token_id <= self.vocab_size
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
def decode(self, token_ids: Tensor) -> List[str]:
|
| 250 |
+
"""Decodes token ids back to sentences"""
|
| 251 |
+
sentences = []
|
| 252 |
+
for _token_ids in token_ids: # iterate over batches
|
| 253 |
+
_token_ids = self._filter_token_ids(_token_ids.tolist())
|
| 254 |
+
sentence = self.tokenizer.decode(_token_ids) # type: ignore
|
| 255 |
+
sentences.append(sentence)
|
| 256 |
+
|
| 257 |
+
return sentences
|
| 258 |
+
|
| 259 |
+
def lookup(self, token_ids: Tensor) -> Tensor:
|
| 260 |
+
return self.embedding(token_ids)
|
multimodal/examples/mugen/generation/video_vqvae.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from torchmultimodal.models.video_gpt.video_vqvae import (
|
| 10 |
+
preprocess_int_conv_params,
|
| 11 |
+
VideoDecoder,
|
| 12 |
+
VideoEncoder,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from torchmultimodal.models.vqvae import VQVAE
|
| 16 |
+
from torchmultimodal.utils.common import load_module_from_url, remove_grad
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
MUGEN_PRETRAINED_MAPPING = {
|
| 20 |
+
"mugen_L32": "https://download.pytorch.org/models/multimodal/mugen/mugen_video_vqvae_L32.pt",
|
| 21 |
+
"mugen_L16": "https://download.pytorch.org/models/multimodal/mugen/mugen_video_vqvae_L16.pt",
|
| 22 |
+
"mugen_L8": "https://download.pytorch.org/models/multimodal/mugen/mugen_video_vqvae_L8.pt",
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def video_vqvae_mugen(
|
| 27 |
+
in_channel_dim: int = 3,
|
| 28 |
+
encoder_hidden_dim: int = 240,
|
| 29 |
+
encoder_kernel_size: int = 3,
|
| 30 |
+
n_res_layers: int = 4,
|
| 31 |
+
attn_hidden_dim: int = 240,
|
| 32 |
+
num_embeddings: int = 2048,
|
| 33 |
+
embedding_dim: int = 256,
|
| 34 |
+
decoder_hidden_dim: int = 240,
|
| 35 |
+
decoder_kernel_size: int = 3,
|
| 36 |
+
pretrained_model_key: Optional[str] = None,
|
| 37 |
+
freeze_model: bool = False,
|
| 38 |
+
) -> VQVAE:
|
| 39 |
+
"""Constructor for MUGEN's Video VQVAE. Expects input video data of shape ``{8,16,32}x256x256``.
|
| 40 |
+
Trained for tokenization of video data and use in video-audio-text retrieval and generation tasks.
|
| 41 |
+
See Hayes et al. 2022 for more details: https://arxiv.org/pdf/2204.08058.pdf
|
| 42 |
+
Code ref:
|
| 43 |
+
https://github.com/mugen-org/MUGEN_baseline/blob/main/lib/models/video_vqvae/vqvae.py
|
| 44 |
+
https://github.com/mugen-org/MUGEN_baseline/blob/main/generation/experiments/vqvae/VideoVQVAE_L32.sh
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
in_channel_dim (int, optional): Size of channel dim in input. Defaults to ``3``.
|
| 48 |
+
encoder_hidden_dim (int, optional): Size of channel dims in encoder conv layers. Defaults to ``240``.
|
| 49 |
+
encoder_kernel_size (int, optional): Kernel size for encoder. Defaults to ``3``.
|
| 50 |
+
n_res_layers (int, optional): Number of ``AttentionResidualBlocks`` to include in encoder and decoder.
|
| 51 |
+
Defaults to ``4``.
|
| 52 |
+
attn_hidden_dim (int, optional): Size of hidden dim of
|
| 53 |
+
:class:`~torchmultimodal.models.video_gpt.video_vqvae.AttentionResidualBlocks`. Defaults to ``240``.
|
| 54 |
+
num_embeddings (int, optional): Number of embedding vectors used in
|
| 55 |
+
:class:`~torchmultimodal.modules.layers.codebook.Codebook`. Defaults to ``2048``.
|
| 56 |
+
embedding_dim (int, optional): Dimensionality of embedding vectors in
|
| 57 |
+
:class:`~torchmultimodal.modules.layers.codebook.Codebook`. Defaults to ``256``.
|
| 58 |
+
decoder_hidden_dim (int, optional): Size of channel dims in decoder conv tranpose layers.
|
| 59 |
+
Defaults to ``240``.
|
| 60 |
+
decoder_kernel_size (int, optional): Kernel size for decoder. Defaults to ``3``.
|
| 61 |
+
pretrained_model_key (str, optional): Load a specified MUGEN VQVAE checkpoint.
|
| 62 |
+
freeze_model (bool): Whether to freeze the weights of the pretrained model. Defaults to ``False``.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
An instance of :class:`~torchmultimodal.models.vqvae.VQVAE` constructed with:
|
| 66 |
+
* :class:`~torchmultimodal.model.video_gpt.video_vqvae.VideoEncoder`
|
| 67 |
+
* :class:`~torchmultimodal.model.video_gpt.video_vqvae.VideoDecoder`
|
| 68 |
+
"""
|
| 69 |
+
encoder_strides = ((2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2), (1, 1, 1))
|
| 70 |
+
decoder_strides = ((2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2))
|
| 71 |
+
encoder_n_layers = len(encoder_strides)
|
| 72 |
+
decoder_n_layers = len(decoder_strides)
|
| 73 |
+
encoder_in_channel_dims = (in_channel_dim,) + (encoder_hidden_dim,) * max(
|
| 74 |
+
encoder_n_layers - 1, 0
|
| 75 |
+
)
|
| 76 |
+
decoder_out_channel_dims = (decoder_hidden_dim,) * max(decoder_n_layers - 1, 0) + (
|
| 77 |
+
in_channel_dim,
|
| 78 |
+
)
|
| 79 |
+
encoder_kernel_sizes_fixed = preprocess_int_conv_params(
|
| 80 |
+
encoder_in_channel_dims, encoder_kernel_size
|
| 81 |
+
)
|
| 82 |
+
decoder_kernel_sizes_fixed = preprocess_int_conv_params(
|
| 83 |
+
decoder_out_channel_dims, decoder_kernel_size
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
encoder = VideoEncoder(
|
| 87 |
+
encoder_in_channel_dims,
|
| 88 |
+
encoder_kernel_sizes_fixed,
|
| 89 |
+
encoder_strides,
|
| 90 |
+
embedding_dim,
|
| 91 |
+
n_res_layers,
|
| 92 |
+
attn_hidden_dim,
|
| 93 |
+
)
|
| 94 |
+
decoder = VideoDecoder(
|
| 95 |
+
decoder_out_channel_dims,
|
| 96 |
+
decoder_kernel_sizes_fixed,
|
| 97 |
+
decoder_strides,
|
| 98 |
+
embedding_dim,
|
| 99 |
+
n_res_layers,
|
| 100 |
+
attn_hidden_dim,
|
| 101 |
+
)
|
| 102 |
+
model = VQVAE(encoder, decoder, num_embeddings, embedding_dim)
|
| 103 |
+
|
| 104 |
+
if pretrained_model_key is not None:
|
| 105 |
+
if pretrained_model_key not in MUGEN_PRETRAINED_MAPPING.keys():
|
| 106 |
+
raise KeyError(f"Invalid pretrained model key: {pretrained_model_key}")
|
| 107 |
+
|
| 108 |
+
load_module_from_url(model, MUGEN_PRETRAINED_MAPPING[pretrained_model_key])
|
| 109 |
+
|
| 110 |
+
if freeze_model:
|
| 111 |
+
remove_grad(model)
|
| 112 |
+
|
| 113 |
+
return model
|
multimodal/examples/mugen/retrieval/README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MUGEN Retrieval
|
| 2 |
+
|
| 3 |
+
This directory contains reference training and evaluation scripts for MUGEN's video-text retrieval model, including a tutorial notebook for the model usage [Colab](https://colab.research.google.com/drive/1gZfz1jsy79CNCK9t2_r43yt3z7v-w4HS?usp=sharing) or [GitHub](https://github.com/facebookresearch/multimodal/blob/main/examples/mugen/retrieval/tutorial.ipynb).
|
| 4 |
+
|
| 5 |
+
## Model
|
| 6 |
+
MUGEN's video-text retrieval model follows from [VideoCLIP](https://arxiv.org/abs/2109.14084), a contrastive model for video and text.
|
| 7 |
+
|
| 8 |
+
The name "VideoCLIP" refers to its similarities to OpenAI's [CLIP](https://arxiv.org/abs/2103.00020), which was originally proposed for zero-shot learning of image classification tasks by “drawing cues” from text data with the corresponding visual concepts. Unlike various predecessor models based on supervised learning, CLIP does not have to be trained on the task-specific datasets or fine-tuned with a task-specific head. The model learns a joint embedding space for both image and text data and optimizes a scaled cosine similarity function between the image and text embedding vectors. The loss function is the sum of the normalized cosine similarities for every pair of image-and-text samples. Each embedding is trained with a unimodal encoder, e.g., a transformer for text, vision transformer (ViT) or ResNet for image.
|
| 9 |
+
|
| 10 |
+
The VideoCLIP model follows the CLIP architecture but replaces the image encoder with a video encoder. VideoCLIP's video encoder is backed by [Separable 3D CNN (S3D)](https://arxiv.org/abs/1712.04851), a video classification model, and the text encoder is backed by [DistilBERT](https://arxiv.org/abs/1910.01108), a lightweight transformer for language modeling.
|
| 11 |
+
|
| 12 |
+
## Training
|
| 13 |
+
The configurable parameters for training can be found in `configs/train.yaml`. Note that the training script supports training on 1 or more devices on a single node. Then run the following command:
|
| 14 |
+
```
|
| 15 |
+
python train.py config=configs/train.yaml
|
| 16 |
+
```
|
| 17 |
+
A checkpoint file with the best-performing weights will be saved under `{default_root_dir}/lightning_logs/`, where `default_root_dir` is specified in the training config. If `default_root_dir` is `null`, then it will act as your working directory.
|
| 18 |
+
|
| 19 |
+
## Evaluation
|
| 20 |
+
The configurable parameters for evaluation can be found in `configs/eval.yaml`. You can choose to replace `checkpoint_path` with the path to your checkpoint from the training step, or keep the default `checkpoint_path` to load the MUGEN authors' weights (fit to our implementation). Then run the following command:
|
| 21 |
+
```
|
| 22 |
+
python eval.py config=configs/eval.yaml
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
Using the default arguments in `configs/eval.yaml` (including the MUGEN authors' published weights), we ran the evaluation script on the full MUGEN test set and got the following results:
|
| 26 |
+
|
| 27 |
+
| Metric (%) | MUGEN Results | TorchMultimodal Results |
|
| 28 |
+
| ----------- | ----------- | ----------- |
|
| 29 |
+
| Text2video top-1 recall | 8.54 | 8.26 |
|
| 30 |
+
| Text2video top-5 recall | 22.50 | 22.34 |
|
| 31 |
+
| Text2video top-10 recall | 31.71 | 31.68 |
|
| 32 |
+
| Video2text top-1 recall | 10.61 | 10.79 |
|
| 33 |
+
| Video2text top-5 recall | 25.72 | 25.70 |
|
| 34 |
+
| Video2text top-10 recall | 34.70 | 34.60 |
|
multimodal/examples/mugen/retrieval/configs/eval.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: examples.mugen.retrieval.definitions.EvaluationArgs
|
| 2 |
+
dataset_args:
|
| 3 |
+
_target_: examples.mugen.data.mugen_dataset.MUGENDatasetArgs
|
| 4 |
+
data_path: "datasets/coinrun/coinrun_dataset_jsons/release"
|
| 5 |
+
asset_path: "datasets/coinrun/assets"
|
| 6 |
+
sample_every_n_frames: 3
|
| 7 |
+
sequence_length: 32
|
| 8 |
+
audio_sample_rate: 22050
|
| 9 |
+
audio_sample_length: 70560
|
| 10 |
+
resolution: 256
|
| 11 |
+
bbox_smap_for_agent: False
|
| 12 |
+
bbox_smap_for_monsters: False
|
| 13 |
+
use_manual_annotation: True
|
| 14 |
+
use_auto_annotation: False
|
| 15 |
+
use_downsampled_trainset: False
|
| 16 |
+
fixed_start_idx: False
|
| 17 |
+
get_game_frame: True
|
| 18 |
+
get_seg_map: False
|
| 19 |
+
get_text_desc: True
|
| 20 |
+
get_audio: False
|
| 21 |
+
debug: False
|
| 22 |
+
datamodule_args:
|
| 23 |
+
_target_: examples.mugen.retrieval.definitions.DataModuleArgs
|
| 24 |
+
batch_size: 16
|
| 25 |
+
num_workers: 4
|
| 26 |
+
shuffle: False
|
| 27 |
+
bert_text_transform:
|
| 28 |
+
_target_: examples.mugen.retrieval.definitions.BertTextTransformArgs
|
| 29 |
+
video_transform:
|
| 30 |
+
_target_: examples.mugen.retrieval.definitions.VideoTransformArgs
|
| 31 |
+
lightningmodule_args:
|
| 32 |
+
_target_: examples.mugen.retrieval.definitions.LightningModuleArgs
|
| 33 |
+
logit_scale: 0.07
|
| 34 |
+
logit_scale_max: 100.0
|
| 35 |
+
videoclip_args:
|
| 36 |
+
_target_: examples.mugen.retrieval.definitions.VideoCLIPArgs
|
| 37 |
+
text_pretrained: False
|
| 38 |
+
text_trainable: False
|
| 39 |
+
text_model_name: "distilbert-base-uncased"
|
| 40 |
+
text_model_config: null
|
| 41 |
+
text_padding_value: 0
|
| 42 |
+
video_pretrained: False
|
| 43 |
+
video_trainable: False
|
| 44 |
+
video_pretrain_path: "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/S3D_kinetics400.pt"
|
| 45 |
+
proj_out_dim: 256
|
| 46 |
+
proj_dropout: 0.1
|
| 47 |
+
checkpoint_path: "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/videoclip_lightning_mugen.pt"
|
| 48 |
+
accelerator: "auto"
|
multimodal/examples/mugen/retrieval/configs/train.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: examples.mugen.retrieval.definitions.TrainingArgs
|
| 2 |
+
dataset_args:
|
| 3 |
+
_target_: examples.mugen.data.mugen_dataset.MUGENDatasetArgs
|
| 4 |
+
data_path: "datasets/coinrun/coinrun_dataset_jsons/release"
|
| 5 |
+
asset_path: "datasets/coinrun/assets"
|
| 6 |
+
sample_every_n_frames: 3
|
| 7 |
+
sequence_length: 32
|
| 8 |
+
audio_sample_rate: 22050
|
| 9 |
+
audio_sample_length: 70560
|
| 10 |
+
resolution: 224
|
| 11 |
+
bbox_smap_for_agent: False
|
| 12 |
+
bbox_smap_for_monsters: False
|
| 13 |
+
use_manual_annotation: True
|
| 14 |
+
use_auto_annotation: False
|
| 15 |
+
use_downsampled_trainset: False
|
| 16 |
+
fixed_start_idx: False
|
| 17 |
+
get_game_frame: True
|
| 18 |
+
get_seg_map: False
|
| 19 |
+
get_text_desc: True
|
| 20 |
+
get_audio: False
|
| 21 |
+
debug: False
|
| 22 |
+
datamodule_args:
|
| 23 |
+
_target_: examples.mugen.retrieval.definitions.DataModuleArgs
|
| 24 |
+
batch_size: 16
|
| 25 |
+
num_workers: 4
|
| 26 |
+
shuffle: False
|
| 27 |
+
bert_text_transform:
|
| 28 |
+
_target_: examples.mugen.retrieval.definitions.BertTextTransformArgs
|
| 29 |
+
video_transform:
|
| 30 |
+
_target_: examples.mugen.retrieval.definitions.VideoTransformArgs
|
| 31 |
+
lightningmodule_args:
|
| 32 |
+
_target_: examples.mugen.retrieval.definitions.LightningModuleArgs
|
| 33 |
+
logit_scale: 0.07
|
| 34 |
+
logit_scale_max: 100.0
|
| 35 |
+
learning_rate: 0.001
|
| 36 |
+
weight_decay: 0.001
|
| 37 |
+
videoclip_args:
|
| 38 |
+
_target_: examples.mugen.retrieval.definitions.VideoCLIPArgs
|
| 39 |
+
text_pretrained: True
|
| 40 |
+
text_trainable: False
|
| 41 |
+
text_model_name: "distilbert-base-uncased"
|
| 42 |
+
text_model_config: null
|
| 43 |
+
text_padding_value: 0
|
| 44 |
+
video_pretrained: True
|
| 45 |
+
video_trainable: True
|
| 46 |
+
video_pretrain_path: "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/S3D_kinetics400.pt"
|
| 47 |
+
proj_out_dim: 256
|
| 48 |
+
proj_dropout: 0.1
|
| 49 |
+
accelerator: "auto"
|
| 50 |
+
devices: 4
|
| 51 |
+
max_epochs: 20
|
| 52 |
+
log_every_n_steps: 100
|
| 53 |
+
default_root_dir: null
|
multimodal/examples/mugen/retrieval/definitions.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, Dict, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
from examples.mugen.data.mugen_dataset import MUGENDatasetArgs
|
| 11 |
+
|
| 12 |
+
from torchmultimodal.transforms.video_transform import (
|
| 13 |
+
DEFAULT_MEAN,
|
| 14 |
+
DEFAULT_RESIZE_SHAPE,
|
| 15 |
+
DEFAULT_STD,
|
| 16 |
+
MUGEN_DEFAULT_TIME_SAMPLES,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class BertTextTransformArgs:
|
| 22 |
+
vocab_file: str = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
|
| 23 |
+
do_lower_case: bool = True
|
| 24 |
+
start_token: int = 101
|
| 25 |
+
end_token: int = 102
|
| 26 |
+
padding_value: int = 0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class VideoTransformArgs:
|
| 31 |
+
time_samples: int = MUGEN_DEFAULT_TIME_SAMPLES
|
| 32 |
+
mean: Tuple[float] = DEFAULT_MEAN
|
| 33 |
+
std: Tuple[float] = DEFAULT_STD
|
| 34 |
+
resize_shape: Tuple[int, int] = DEFAULT_RESIZE_SHAPE
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class DataModuleArgs:
|
| 39 |
+
batch_size: int = 16
|
| 40 |
+
num_workers: int = 4
|
| 41 |
+
shuffle: bool = False
|
| 42 |
+
bert_text_transform: BertTextTransformArgs = BertTextTransformArgs()
|
| 43 |
+
video_transform: VideoTransformArgs = VideoTransformArgs()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class LightningModuleArgs:
|
| 48 |
+
logit_scale: float = 0.07
|
| 49 |
+
logit_scale_max: float = 100.0
|
| 50 |
+
learning_rate: float = 1e-3
|
| 51 |
+
weight_decay: float = 1e-3
|
| 52 |
+
recall_ks: Tuple[int] = (1, 5, 10)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class VideoCLIPArgs:
|
| 57 |
+
text_pretrained: bool = False
|
| 58 |
+
text_trainable: bool = False
|
| 59 |
+
text_model_name: str = "distilbert-base-uncased"
|
| 60 |
+
text_model_config: Optional[Dict[str, Any]] = None
|
| 61 |
+
text_padding_value: int = 0
|
| 62 |
+
video_pretrained: bool = False
|
| 63 |
+
video_trainable: bool = False
|
| 64 |
+
video_pretrain_path: str = (
|
| 65 |
+
"https://pytorch.s3.amazonaws.com/models/multimodal/mugen/S3D_kinetics400.pt"
|
| 66 |
+
)
|
| 67 |
+
proj_out_dim: int = 256
|
| 68 |
+
proj_dropout: float = 0.1
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class EvaluationArgs:
|
| 73 |
+
dataset_args: MUGENDatasetArgs = MUGENDatasetArgs(
|
| 74 |
+
get_game_frame=True,
|
| 75 |
+
get_text_desc=True,
|
| 76 |
+
resolution=256,
|
| 77 |
+
fixed_start_idx=False,
|
| 78 |
+
use_manual_annotation=True,
|
| 79 |
+
use_auto_annotation=False,
|
| 80 |
+
)
|
| 81 |
+
datamodule_args: DataModuleArgs = DataModuleArgs()
|
| 82 |
+
lightningmodule_args: LightningModuleArgs = LightningModuleArgs()
|
| 83 |
+
videoclip_args: VideoCLIPArgs = VideoCLIPArgs()
|
| 84 |
+
checkpoint_path: str = "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/videoclip_lightning_mugen.pt"
|
| 85 |
+
accelerator: str = "auto"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass
|
| 89 |
+
class TrainingArgs:
|
| 90 |
+
dataset_args: MUGENDatasetArgs = MUGENDatasetArgs(
|
| 91 |
+
get_game_frame=True,
|
| 92 |
+
get_text_desc=True,
|
| 93 |
+
resolution=224,
|
| 94 |
+
fixed_start_idx=False,
|
| 95 |
+
use_manual_annotation=True,
|
| 96 |
+
use_auto_annotation=False,
|
| 97 |
+
)
|
| 98 |
+
datamodule_args: DataModuleArgs = DataModuleArgs()
|
| 99 |
+
lightningmodule_args: LightningModuleArgs = LightningModuleArgs()
|
| 100 |
+
videoclip_args: VideoCLIPArgs = VideoCLIPArgs()
|
| 101 |
+
accelerator: str = "auto"
|
| 102 |
+
devices: int = 4
|
| 103 |
+
max_epochs: int = 1000
|
| 104 |
+
log_every_n_steps: int = 100
|
| 105 |
+
default_root_dir: Optional[str] = None
|
multimodal/examples/mugen/retrieval/eval.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from examples.mugen.data.bert_text_transform import BertTextTransform
|
| 8 |
+
from examples.mugen.data.mugen_datamodules import MUGENDataModule
|
| 9 |
+
from examples.mugen.data.mugen_dataset import MUGENDatasetArgs
|
| 10 |
+
from examples.mugen.retrieval.model import VideoCLIPLightningModule
|
| 11 |
+
from hydra.utils import instantiate
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
from pytorch_lightning import Trainer
|
| 14 |
+
from torchmultimodal.transforms.video_transform import VideoTransform
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_yaml_config():
|
| 18 |
+
cli_conf = OmegaConf.from_cli()
|
| 19 |
+
if "config" not in cli_conf:
|
| 20 |
+
raise ValueError(
|
| 21 |
+
"Please pass 'config' to specify configuration yaml file for running VideoCLIP evaluation"
|
| 22 |
+
)
|
| 23 |
+
yaml_conf = OmegaConf.load(cli_conf.config)
|
| 24 |
+
conf = instantiate(yaml_conf)
|
| 25 |
+
return conf
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def evaluate():
|
| 29 |
+
args = get_yaml_config()
|
| 30 |
+
|
| 31 |
+
dataset_args: MUGENDatasetArgs = args.dataset_args
|
| 32 |
+
datamodule = MUGENDataModule(
|
| 33 |
+
dataset_args,
|
| 34 |
+
text_transform=BertTextTransform(
|
| 35 |
+
**vars(args.datamodule_args.bert_text_transform)
|
| 36 |
+
),
|
| 37 |
+
video_transform=VideoTransform(**vars(args.datamodule_args.video_transform)),
|
| 38 |
+
batch_size=args.datamodule_args.batch_size,
|
| 39 |
+
num_workers=args.datamodule_args.num_workers,
|
| 40 |
+
shuffle=args.datamodule_args.shuffle,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
model = VideoCLIPLightningModule.load_from_checkpoint(
|
| 44 |
+
args.checkpoint_path,
|
| 45 |
+
**vars(args.lightningmodule_args),
|
| 46 |
+
**vars(args.videoclip_args),
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
trainer = Trainer(accelerator=args.accelerator, devices=1)
|
| 50 |
+
trainer.test(model, dataloaders=datamodule.test_dataloader())
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
evaluate()
|
multimodal/examples/mugen/retrieval/model.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import warnings
|
| 8 |
+
from typing import Any, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from examples.mugen.retrieval.video_clip import videoclip
|
| 13 |
+
from pytorch_lightning import LightningModule
|
| 14 |
+
from torchmetrics import Recall
|
| 15 |
+
|
| 16 |
+
from torchmultimodal.modules.losses.contrastive_loss_with_temperature import (
|
| 17 |
+
ContrastiveLossWithTemperature,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class VideoCLIPLightningModule(LightningModule):
|
| 22 |
+
"""PyTorch Lightning module for evaluating VideoCLIP model.
|
| 23 |
+
Args:
|
| 24 |
+
logit_scale (float): Initial log-temperature value for contrastive loss funtion.
|
| 25 |
+
Defaults to ``0.07``, MUGEN's log-temperature value at initialization.
|
| 26 |
+
logit_scale_max (float): Maximum log-temperature value for contrastive loss function.
|
| 27 |
+
Defaults to ``100``, MUGEN's maximum log-temperature value.
|
| 28 |
+
learning_rate (float): optimizer learning rate.
|
| 29 |
+
Defaults to ``1e-3``, MUGEN's learning rate.
|
| 30 |
+
weight_decay (float): optimizer weight decay.
|
| 31 |
+
Defaults to ``1e-3``, MUGEN's weight decay.
|
| 32 |
+
recall_ks (Tuple[int]): tuple of top-``k``'s for calculating recall.
|
| 33 |
+
Defaults to ``(1, 5, 10)``, i.e. top-1 recall, top-5 recall, and top-10 recall.
|
| 34 |
+
**videoclip_kwargs (Any): Keyword arguments for the videoCLIP model builder.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
logit_scale: float = 0.07,
|
| 40 |
+
logit_scale_max: float = 100,
|
| 41 |
+
learning_rate: float = 1e-3,
|
| 42 |
+
weight_decay: float = 1e-3,
|
| 43 |
+
recall_ks: Tuple[int] = (1, 5, 10),
|
| 44 |
+
**videoclip_kwargs: Any,
|
| 45 |
+
):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.model = videoclip(**videoclip_kwargs)
|
| 48 |
+
self.contrastive_loss = ContrastiveLossWithTemperature(
|
| 49 |
+
logit_scale=logit_scale,
|
| 50 |
+
logit_scale_min=None,
|
| 51 |
+
logit_scale_max=logit_scale_max,
|
| 52 |
+
)
|
| 53 |
+
self.lr = learning_rate
|
| 54 |
+
self.weight_decay = weight_decay
|
| 55 |
+
|
| 56 |
+
self.recall_ks = set(recall_ks)
|
| 57 |
+
if len(self.recall_ks) != len(recall_ks):
|
| 58 |
+
warnings.warn("Duplicate `k` values in `recall_ks` are ignored.")
|
| 59 |
+
self.metrics = torch.nn.ModuleDict()
|
| 60 |
+
for k in self.recall_ks:
|
| 61 |
+
self.metrics.update(
|
| 62 |
+
{f"v2t_recall_{k}": Recall(top_k=k), f"t2v_recall_{k}": Recall(top_k=k)}
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def _collect_embeddings(self, outputs):
|
| 66 |
+
text_embeddings = [batch.embeddings_a for batch in outputs]
|
| 67 |
+
video_embeddings = [batch.embeddings_b for batch in outputs]
|
| 68 |
+
|
| 69 |
+
embeddings = {
|
| 70 |
+
"text": torch.cat(text_embeddings),
|
| 71 |
+
"video": torch.cat(video_embeddings),
|
| 72 |
+
}
|
| 73 |
+
return embeddings
|
| 74 |
+
|
| 75 |
+
def _compute_recall(self, split, text_embedding, video_embedding):
|
| 76 |
+
similarity_matrix = text_embedding @ video_embedding.T
|
| 77 |
+
num_samples = similarity_matrix.shape[0]
|
| 78 |
+
target_matrix = torch.eye(
|
| 79 |
+
n=num_samples, dtype=int, device=similarity_matrix.device
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
for k in self.recall_ks:
|
| 83 |
+
v2t_recall = self.metrics[f"v2t_recall_{k}"]
|
| 84 |
+
v2t_recall(preds=similarity_matrix.T, target=target_matrix)
|
| 85 |
+
self.log(f"{split}/Recall@{k} (video query, text retrieval)", v2t_recall)
|
| 86 |
+
|
| 87 |
+
t2v_recall = self.metrics[f"t2v_recall_{k}"]
|
| 88 |
+
t2v_recall(preds=similarity_matrix, target=target_matrix)
|
| 89 |
+
self.log(f"{split}/Recall@{k} (text query, video retrieval)", t2v_recall)
|
| 90 |
+
|
| 91 |
+
def configure_optimizers(self):
|
| 92 |
+
params = self.parameters()
|
| 93 |
+
optimizer = torch.optim.AdamW(
|
| 94 |
+
params, lr=self.lr, weight_decay=self.weight_decay
|
| 95 |
+
)
|
| 96 |
+
return optimizer
|
| 97 |
+
|
| 98 |
+
def training_step(self, batch, batch_idx):
|
| 99 |
+
text, video = batch.get("text"), batch.get("video")
|
| 100 |
+
model_output = self.model(features_a=text, features_b=video)
|
| 101 |
+
loss = self.contrastive_loss(
|
| 102 |
+
model_output.embeddings_a, model_output.embeddings_b
|
| 103 |
+
)
|
| 104 |
+
self.log(
|
| 105 |
+
"train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
|
| 106 |
+
)
|
| 107 |
+
return {"loss": loss, "model_output": model_output}
|
| 108 |
+
|
| 109 |
+
def validation_step(self, batch, batch_idx):
|
| 110 |
+
text, video = batch.get("text"), batch.get("video")
|
| 111 |
+
model_output = self.model(features_a=text, features_b=video)
|
| 112 |
+
loss = self.contrastive_loss(
|
| 113 |
+
model_output.embeddings_a, model_output.embeddings_b
|
| 114 |
+
)
|
| 115 |
+
self.log(
|
| 116 |
+
"validation/loss",
|
| 117 |
+
loss,
|
| 118 |
+
on_step=True,
|
| 119 |
+
on_epoch=True,
|
| 120 |
+
prog_bar=True,
|
| 121 |
+
logger=True,
|
| 122 |
+
)
|
| 123 |
+
return {"loss": loss, "model_output": model_output}
|
| 124 |
+
|
| 125 |
+
def validation_epoch_end(self, outputs):
|
| 126 |
+
model_outputs = [batch["model_output"] for batch in outputs]
|
| 127 |
+
all_embeddings = self._collect_embeddings(model_outputs)
|
| 128 |
+
text_embedding, video_embedding = (
|
| 129 |
+
all_embeddings["text"],
|
| 130 |
+
all_embeddings["video"],
|
| 131 |
+
)
|
| 132 |
+
self._compute_recall("validation", text_embedding, video_embedding)
|
| 133 |
+
|
| 134 |
+
def test_step(self, batch, batch_idx):
|
| 135 |
+
text, video = batch.get("text"), batch.get("video")
|
| 136 |
+
model_output = self.model(features_a=text, features_b=video)
|
| 137 |
+
return model_output
|
| 138 |
+
|
| 139 |
+
def test_epoch_end(self, outputs):
|
| 140 |
+
all_embeddings = self._collect_embeddings(outputs)
|
| 141 |
+
text_embedding, video_embedding = (
|
| 142 |
+
all_embeddings["text"],
|
| 143 |
+
all_embeddings["video"],
|
| 144 |
+
)
|
| 145 |
+
self._compute_recall("test", text_embedding, video_embedding)
|
multimodal/examples/mugen/retrieval/train.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from examples.mugen.data.bert_text_transform import BertTextTransform
|
| 8 |
+
from examples.mugen.data.mugen_datamodules import MUGENDataModule
|
| 9 |
+
from examples.mugen.data.mugen_dataset import MUGENDatasetArgs
|
| 10 |
+
from examples.mugen.retrieval.model import VideoCLIPLightningModule
|
| 11 |
+
from hydra.utils import instantiate
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
from pytorch_lightning import Trainer
|
| 14 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 15 |
+
from torchmultimodal.transforms.video_transform import VideoTransform
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_yaml_config():
|
| 19 |
+
cli_conf = OmegaConf.from_cli()
|
| 20 |
+
if "config" not in cli_conf:
|
| 21 |
+
raise ValueError(
|
| 22 |
+
"Please pass 'config' to specify configuration yaml file for running VideoCLIP training"
|
| 23 |
+
)
|
| 24 |
+
yaml_conf = OmegaConf.load(cli_conf.config)
|
| 25 |
+
conf = instantiate(yaml_conf)
|
| 26 |
+
return conf
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def train():
|
| 30 |
+
args = get_yaml_config()
|
| 31 |
+
|
| 32 |
+
dataset_args: MUGENDatasetArgs = args.dataset_args
|
| 33 |
+
datamodule = MUGENDataModule(
|
| 34 |
+
dataset_args,
|
| 35 |
+
text_transform=BertTextTransform(
|
| 36 |
+
**vars(args.datamodule_args.bert_text_transform)
|
| 37 |
+
),
|
| 38 |
+
video_transform=VideoTransform(**vars(args.datamodule_args.video_transform)),
|
| 39 |
+
batch_size=args.datamodule_args.batch_size,
|
| 40 |
+
num_workers=args.datamodule_args.num_workers,
|
| 41 |
+
shuffle=args.datamodule_args.shuffle,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
model = VideoCLIPLightningModule(
|
| 45 |
+
**vars(args.lightningmodule_args),
|
| 46 |
+
**vars(args.videoclip_args),
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
checkpoint_callback = ModelCheckpoint(save_top_k=-1)
|
| 50 |
+
trainer = Trainer(
|
| 51 |
+
accelerator=args.accelerator,
|
| 52 |
+
devices=args.devices,
|
| 53 |
+
strategy="ddp_find_unused_parameters_false",
|
| 54 |
+
max_epochs=args.max_epochs,
|
| 55 |
+
log_every_n_steps=args.log_every_n_steps,
|
| 56 |
+
default_root_dir=args.default_root_dir,
|
| 57 |
+
callbacks=[checkpoint_callback],
|
| 58 |
+
)
|
| 59 |
+
trainer.fit(
|
| 60 |
+
model=model,
|
| 61 |
+
train_dataloaders=datamodule.train_dataloader(),
|
| 62 |
+
val_dataloaders=datamodule.val_dataloader(),
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
if __name__ == "__main__":
|
| 67 |
+
train()
|