tuandunghcmut committed on
Commit 56323fb · verified · 1 Parent(s): f0384a9

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. multimodal/examples/albef/configs/retrieval.yaml +73 -0
  2. multimodal/examples/albef/configs/vqa.yaml +78 -0
  3. multimodal/examples/albef/data/__init__.py +5 -0
  4. multimodal/examples/albef/data/retrieval_datamodule.py +188 -0
  5. multimodal/examples/albef/data/retrieval_dataset.py +149 -0
  6. multimodal/examples/albef/data/transforms.py +141 -0
  7. multimodal/examples/albef/data/vqa_datamodules.py +206 -0
  8. multimodal/examples/albef/data/vqa_dataset.py +115 -0
  9. multimodal/examples/common/data/__init__.py +7 -0
  10. multimodal/examples/common/data/multidata.py +194 -0
  11. multimodal/examples/flava/callbacks/__init__.py +7 -0
  12. multimodal/examples/flava/callbacks/multimodal_eval.py +108 -0
  13. multimodal/examples/flava/configs/finetuning/qnli.yaml +48 -0
  14. multimodal/examples/flava/configs/finetuning/rendered_sst2.yaml +37 -0
  15. multimodal/examples/flava/configs/pretraining/debug.yaml +61 -0
  16. multimodal/examples/flava/data/__init__.py +10 -0
  17. multimodal/examples/flava/data/datamodules.py +529 -0
  18. multimodal/examples/flava/data/imagenet_zeroshot_data.py +1095 -0
  19. multimodal/examples/flava/data/transforms.py +131 -0
  20. multimodal/examples/flava/data/utils.py +80 -0
  21. multimodal/examples/flava/native/README.md +43 -0
  22. multimodal/examples/flava/native/__init__.py +5 -0
  23. multimodal/examples/flava/native/configs/1.8b.yaml +79 -0
  24. multimodal/examples/flava/native/configs/10b.yaml +80 -0
  25. multimodal/examples/flava/native/configs/2.7b.yaml +79 -0
  26. multimodal/examples/flava/native/configs/4.8b.yaml +79 -0
  27. multimodal/examples/flava/native/configs/900m.yaml +79 -0
  28. multimodal/examples/flava/native/configs/pretrain_debug.yaml +63 -0
  29. multimodal/examples/flava/native/data.py +560 -0
  30. multimodal/examples/flava/native/model.py +78 -0
  31. multimodal/examples/flava/native/train.py +415 -0
  32. multimodal/examples/flava/native/utils.py +160 -0
  33. multimodal/examples/flava/notebooks/RemapFLAVACheckpoint.ipynb +172 -0
  34. multimodal/examples/flava/tools/convert_weights.py +72 -0
  35. multimodal/examples/mugen/data/README.md +10 -0
  36. multimodal/examples/mugen/data/coinrun/construct_from_json.py +756 -0
  37. multimodal/examples/mugen/data/coinrun/game.py +295 -0
  38. multimodal/examples/mugen/data/coinrun/generate_text_desc.py +435 -0
  39. multimodal/examples/mugen/data/mugen_datamodules.py +112 -0
  40. multimodal/examples/mugen/generation/LoadAndComparePretrainedVQVAE.ipynb +383 -0
  41. multimodal/examples/mugen/generation/README.md +33 -0
  42. multimodal/examples/mugen/generation/text_video_gpt.py +260 -0
  43. multimodal/examples/mugen/generation/video_vqvae.py +113 -0
  44. multimodal/examples/mugen/retrieval/README.md +34 -0
  45. multimodal/examples/mugen/retrieval/configs/eval.yaml +48 -0
  46. multimodal/examples/mugen/retrieval/configs/train.yaml +53 -0
  47. multimodal/examples/mugen/retrieval/definitions.py +105 -0
  48. multimodal/examples/mugen/retrieval/eval.py +54 -0
  49. multimodal/examples/mugen/retrieval/model.py +145 -0
  50. multimodal/examples/mugen/retrieval/train.py +67 -0
multimodal/examples/albef/configs/retrieval.yaml ADDED
@@ -0,0 +1,73 @@
hidden_size: &hidden_size 768
vocab_size: &vocab_size 30522
type_vocab_size: &type_vocab_size 2
max_position_embeddings: &max_position_embeddings 512
pad_token_id: &pad_token_id 0
embed_size: &embed_size 256

seed: 42
world_size: 1
device: "cuda"
dist_url: "env://"
output_path: "./examples/albef/outputs/retrieval_output.pt"

datamodule_args:
  train_files: ["./examples/albef/data_files/coco_train.json"]
  test_files: ["./examples/albef/data_files/coco_test.json"]
  image_root: "./examples/albef/data_files/coco"
  batch_size: 32
  num_workers: 8

vision_encoder_args:
  hidden_size: *hidden_size
  image_size: 384
  patch_size: 16
  num_hidden_layers: 12
  num_attention_heads: 12
  mlp_dim: 3072
  dropout: 0.0
  attention_dropout: 0.0
  layer_norm_eps: 1e-6

text_encoder_args:
  vocab_size: *vocab_size
  hidden_size: *hidden_size
  type_vocab_size: *type_vocab_size
  max_position_embeddings: *max_position_embeddings
  pad_token_id: *pad_token_id
  num_hidden_layers: 6
  num_attention_heads: 12
  intermediate_size: 3072
  layer_norm_eps: 1e-12
  dropout: 0.0

multimodal_encoder_args:
  hidden_size: *hidden_size
  num_hidden_layers: 6
  num_attention_heads: 12
  intermediate_size: 3072
  layer_norm_eps: 1e-12

projection_args:
  in_features: *hidden_size
  out_features: *embed_size

similarity_args:
  embed_size: *embed_size
  queue_size: 65536
  temp: 0.07

training_args:
  log_every_n_steps: 100
  alpha: 0.4
  weight_decay: 0.02
  lr: 1e-5
  min_lr: 1e-6
  max_epochs: 5
  step_size: 100
  warmup_steps: 1
  checkpoint_root: "./examples/albef/checkpoints"

eval_args:
  log_every_n_steps: 100
  k_test: 256
multimodal/examples/albef/configs/vqa.yaml ADDED
@@ -0,0 +1,78 @@
hidden_size: &hidden_size 768
vocab_size: &vocab_size 30522
type_vocab_size: &type_vocab_size 2
max_position_embeddings: &max_position_embeddings 512
pad_token_id: &pad_token_id 0

seed: 42
world_size: 1
device: "cuda"
dist_url: "env://"
output_root: "./examples/albef/outputs"

datamodule_args:
  train_files: ["./examples/albef/data_files/vqa_train.json", "./examples/albef/data_files/vg_qa.json", "./examples/albef/data_files/vqa_val.json"]
  test_files: ["./examples/albef/data_files/vqa_test.json"]
  answer_list: "./examples/albef/data_files/answer_list.json"
  vqa_root: "./examples/albef/data_files/coco"
  vg_root: "./examples/albef/data_files/visual_genome"
  batch_size: 32
  num_workers: 8

vision_encoder_args:
  hidden_size: *hidden_size
  image_size: 384
  patch_size: 16
  num_hidden_layers: 12
  num_attention_heads: 12
  mlp_dim: 3072
  dropout: 0.0
  attention_dropout: 0.0
  layer_norm_eps: 1e-6

text_encoder_args:
  vocab_size: *vocab_size
  hidden_size: *hidden_size
  type_vocab_size: *type_vocab_size
  max_position_embeddings: *max_position_embeddings
  pad_token_id: *pad_token_id
  num_hidden_layers: 6
  num_attention_heads: 12
  intermediate_size: 3072
  layer_norm_eps: 1e-12
  dropout: 0.0

multimodal_encoder_args:
  hidden_size: *hidden_size
  num_hidden_layers: 6
  num_attention_heads: 12
  intermediate_size: 3072
  layer_norm_eps: 1e-12

text_embeddings_args:
  hidden_size: *hidden_size
  vocab_size: *vocab_size
  pad_token_id: *pad_token_id
  max_position_embeddings: *max_position_embeddings
  type_vocab_size: *type_vocab_size
  layer_norm_eps: 1e-12

prediction_head_args:
  hidden_size: *hidden_size
  vocab_size: *vocab_size
  layer_norm_eps: 1e-12

training_args:
  log_every_n_steps: 100
  alpha: 0.4
  weight_decay: 0.02
  lr: 2e-5
  min_lr: 1e-6
  max_epochs: 8
  step_size: 100
  warmup_steps: 4
  checkpoint_root: "./examples/albef/checkpoints"

eval_args:
  log_every_n_steps: 100
  k_test: 128
multimodal/examples/albef/data/__init__.py ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
multimodal/examples/albef/data/retrieval_datamodule.py ADDED
@@ -0,0 +1,188 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple
8
+
9
+ import torch
10
+ from data.retrieval_dataset import (
11
+ ImageToTextRetrievalDataset,
12
+ RetrievalTrainingDataset,
13
+ TextToImageRetrievalDataset,
14
+ )
15
+ from data.transforms import (
16
+ ALBEFTextTransform,
17
+ testing_image_transform,
18
+ training_image_transform,
19
+ )
20
+ from pytorch_lightning import LightningDataModule
21
+ from torch import Tensor
22
+ from torch.nn.utils.rnn import pad_sequence
23
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler
24
+
25
+
26
+ class RetrievalDataModule(LightningDataModule):
27
+ """
28
+ The Data Module for Retrieval task.
29
+
30
+ Args:
31
+ train_files (List[str]): The paths to training json files.
32
+ test_files (List[str]): The paths to testing json files.
33
+ image_root (str): The path to image data directory.
34
+ batch_size (int): The sampling batch size.
35
+ num_workers (int): The number of workers for the distributed mode.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ train_files: List[str],
41
+ test_files: List[str],
42
+ image_root: str,
43
+ batch_size: int,
44
+ num_workers: int,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.train_dataset = RetrievalTrainingDataset(
48
+ train_files,
49
+ image_root,
50
+ training_image_transform(),
51
+ ALBEFTextTransform(truncate=True, max_seq_len=30, add_end_token=False),
52
+ )
53
+
54
+ self.image_dataset = ImageToTextRetrievalDataset(
55
+ test_files,
56
+ image_root,
57
+ testing_image_transform(),
58
+ )
59
+
60
+ self.text_dataset = TextToImageRetrievalDataset(
61
+ test_files,
62
+ ALBEFTextTransform(
63
+ truncate=True,
64
+ pad_to_max_seq_len=True,
65
+ max_seq_len=30,
66
+ add_end_token=False,
67
+ ),
68
+ )
69
+
70
+ self.batch_size = batch_size
71
+ self.num_workers = num_workers
72
+
73
+ def _get_sampler(
74
+ self,
75
+ dataset: Dataset,
76
+ shuffle: bool,
77
+ is_distributed: bool,
78
+ num_tasks: int,
79
+ global_rank: int,
80
+ ) -> Optional[DistributedSampler]:
81
+ # do not return a sampler if is not in distributed mode
82
+ # a default RandomSampler is used in this case
83
+ if not is_distributed:
84
+ return None
85
+
86
+ return DistributedSampler(
87
+ dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
88
+ )
89
+
90
+ def train_dataloader(
91
+ self,
92
+ is_distributed: bool = False,
93
+ num_tasks: int = 0,
94
+ global_rank: int = 0,
95
+ drop_last: bool = True,
96
+ ) -> DataLoader:
97
+ """
98
+ DataLoader Outputs:
99
+ images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
100
+ text (Tensor): Tensor of shape (B, L) of text inputs.
101
+ text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
102
+ idx (Tensor): Tensor of shape (B) of image identifiers.
103
+ """
104
+ sampler = self._get_sampler(
105
+ dataset=self.train_dataset,
106
+ shuffle=True,
107
+ is_distributed=is_distributed,
108
+ num_tasks=num_tasks,
109
+ global_rank=global_rank,
110
+ )
111
+ shuffle = sampler is None
112
+ return DataLoader(
113
+ self.train_dataset,
114
+ batch_size=self.batch_size,
115
+ num_workers=self.num_workers,
116
+ pin_memory=True,
117
+ sampler=sampler,
118
+ shuffle=shuffle,
119
+ collate_fn=retrieval_train_collate_fn,
120
+ drop_last=drop_last,
121
+ )
122
+
123
+ def image_dataloader(
124
+ self,
125
+ drop_last: bool = False,
126
+ ) -> DataLoader:
127
+ """
128
+ DataLoader Outputs:
129
+ images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
130
+ """
131
+ return DataLoader(
132
+ self.image_dataset,
133
+ batch_size=self.batch_size,
134
+ num_workers=self.num_workers,
135
+ pin_memory=True,
136
+ sampler=None,
137
+ shuffle=False,
138
+ collate_fn=None,
139
+ drop_last=drop_last,
140
+ )
141
+
142
+ def text_dataloader(
143
+ self,
144
+ drop_last: bool = False,
145
+ ) -> DataLoader:
146
+ """
147
+ DataLoader Outputs:
148
+ text (Tensor): Tensor of shape (B, L) of text inputs.
149
+ text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
150
+ """
151
+ return DataLoader(
152
+ self.text_dataset,
153
+ batch_size=self.batch_size,
154
+ num_workers=self.num_workers,
155
+ pin_memory=True,
156
+ sampler=None,
157
+ shuffle=False,
158
+ collate_fn=text_collate_fn,
159
+ drop_last=drop_last,
160
+ )
161
+
162
+
163
+ def retrieval_train_collate_fn(
164
+ batch: List[Tuple[Tensor, Tensor, int]],
165
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
166
+ image_list = []
167
+ text_list = []
168
+ idx_list = []
169
+ for image, text, idx in batch:
170
+ image_list.append(image)
171
+ text_list.append(text)
172
+ idx_list.append(idx)
173
+ images = torch.stack(image_list, dim=0)
174
+ text = pad_sequence(text_list, batch_first=True)
175
+ text_atts = (text != 0).type(torch.long)
176
+ idx = Tensor(idx_list).type(torch.long)
177
+ return (
178
+ images,
179
+ text,
180
+ text_atts,
181
+ idx,
182
+ )
183
+
184
+
185
+ def text_collate_fn(batch: List[Tensor]) -> Tuple[Tensor, Tensor]:
186
+ text = pad_sequence(batch, batch_first=True)
187
+ text_atts = (text != 0).type(torch.long)
188
+ return text, text_atts
multimodal/examples/albef/data/retrieval_dataset.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ from typing import Callable, List, Tuple, Union
10
+
11
+ from PIL import Image
12
+ from torch import Tensor
13
+ from torch.utils.data import Dataset
14
+
15
+
16
+ class RetrievalTrainingDataset(Dataset):
17
+ """
18
+ Create the training dataset for Retrieval task.
19
+
20
+ Args:
21
+ ann_file (List[str]): The paths to training annotation json files.
22
+ image_root (str): The path to image data directory.
23
+ image_transform (Callable[[Image.Image], Tensor]): Image data transform.
24
+ text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.
25
+
26
+ Dataset Outputs:
27
+ image (Tensor): Transformed image input tensor of shape (C, H, W).
28
+ caption (Tensor): Transformed text token input ids.
29
+ idx (int): The unique identifier for the image.
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ ann_file: List[str],
35
+ image_root: str,
36
+ image_transform: Callable[[Image.Image], Tensor],
37
+ text_transform: Callable[[Union[List[str], str]], Tensor],
38
+ ) -> None:
39
+ self.ann = []
40
+ for f in ann_file:
41
+ self.ann += json.load(open(f, "r"))
42
+
43
+ self.image_root = image_root
44
+ self.image_transform = image_transform
45
+ self.text_transform = text_transform
46
+
47
+ self.idx = {} # map str image_id from dataset to int ids
48
+ i = 0
49
+ for ann in self.ann:
50
+ image_id = ann["image_id"]
51
+ if image_id not in self.idx.keys():
52
+ self.idx[image_id] = i
53
+ i += 1
54
+
55
+ def __len__(self) -> int:
56
+ return len(self.ann)
57
+
58
+ def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, int]:
59
+ ann = self.ann[index]
60
+ image_path = os.path.join(self.image_root, ann["image"])
61
+ image = Image.open(image_path).convert("RGB")
62
+ image = self.image_transform(image)
63
+ caption = self.text_transform(ann["caption"])
64
+ return image, caption, self.idx[ann["image_id"]]
65
+
66
+
67
+ class ImageToTextRetrievalDataset(Dataset):
68
+ """
69
+ Create the dataset for Image-to-Text Retrieval task.
70
+
71
+ Args:
72
+ ann_file (List[str]): The paths to annotation json files.
73
+ image_root (str): The path to image data directory.
74
+ image_transform (Callable[[Image.Image], Tensor]): Image data transform.
75
+
76
+ Dataset Outputs:
77
+ image (Tensor): Transformed image input tensor of shape (C, H, W).
78
+ """
79
+
80
+ def __init__(
81
+ self,
82
+ ann_file: List[str],
83
+ image_root: str,
84
+ image_transform: Callable[[Image.Image], Tensor],
85
+ ) -> None:
86
+ self.image_root = image_root
87
+ self.image_transform = image_transform
88
+
89
+ self.ann = []
90
+ self.images = [] # paths to all images in the dataset
91
+ self.image_to_text = {} # map image ids to text ids for evaluation
92
+ for f in ann_file:
93
+ self.ann += json.load(open(f, "r"))
94
+
95
+ text_id = 0
96
+ for image_id, ann in enumerate(self.ann):
97
+ self.images.append(ann["image"])
98
+ num_text = len(ann["caption"])
99
+ self.image_to_text[image_id] = list(range(text_id, text_id + num_text))
100
+ text_id += num_text
101
+
102
+ def __len__(self) -> int:
103
+ return len(self.images)
104
+
105
+ def __getitem__(self, index: int) -> Tensor:
106
+ image_path = os.path.join(self.image_root, self.images[index])
107
+ image = Image.open(image_path).convert("RGB")
108
+ image = self.image_transform(image)
109
+ return image
110
+
111
+
112
+ class TextToImageRetrievalDataset(Dataset):
113
+ """
114
+ Create the dataset for Text-to-Image Retrieval task.
115
+
116
+ Args:
117
+ ann_file (List[str]): The paths to annotation json files.
118
+ text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.
119
+
120
+ Dataset Outputs:
121
+ text (Tensor): Transformed text token input ids.
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ ann_file: List[str],
127
+ text_transform: Callable[[Union[List[str], str]], Tensor],
128
+ ) -> None:
129
+ self.text_transform = text_transform
130
+
131
+ self.ann = []
132
+ self.text = [] # all text strings in the dataset
133
+ self.text_to_image = {} # map text ids to image ids for evaluation
134
+ for f in ann_file:
135
+ self.ann += json.load(open(f, "r"))
136
+
137
+ text_id = 0
138
+ for image_id, ann in enumerate(self.ann):
139
+ for caption in ann["caption"]:
140
+ self.text.append(caption)
141
+ self.text_to_image[text_id] = image_id
142
+ text_id += 1
143
+
144
+ def __len__(self) -> int:
145
+ return len(self.text)
146
+
147
+ def __getitem__(self, index: int) -> Tensor:
148
+ text = self.text_transform(self.text[index])
149
+ return text
multimodal/examples/albef/data/transforms.py ADDED
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import List, Tuple, Union

import torch

from torchtext.transforms import PadTransform, Sequential, ToTensor, Truncate
from torchvision import transforms
from transformers.models.bert.tokenization_bert import BertTokenizer

# mean and standard deviation from the ALBEF repo:
# https://github.com/salesforce/ALBEF/blob/main/dataset/__init__.py#L16
MEAN = (0.48145466, 0.4578275, 0.40821073)
STD_DEV = (0.26862954, 0.26130258, 0.27577711)


class ALBEFTextTransform:
    """
    Remove punctuations and trailing spaces in input text and transform it into
    a Tensor of token ids using BERTTokenizer.

    Args:
        pretrained_tokenizer (str): Pretrained tokenizer to use.
            Default: "bert-base-uncased"
        do_pre_process (bool): Whether to pre-process input text.
            Defaults to True.
        truncate (bool): Whether to truncate input text to max_seq_length.
            Defaults to False.
        pad_to_max_seq_len (bool): Whether to pad the sequence to max_seq_length.
        add_end_token (bool): Whether to add the end-of-sentence token.
            Defaults to True.
        max_seq_len (int): The max sequence length after truncating or padding.
            Defaults to 25.
        cls_token_id (int): Value to represent the start of each text.
            Defaults to 101, Hugging Face's BERT cls token id.
        sep_token_id (int): Value to represent the end of each text.
            Defaults to 102, Hugging Face's BERT sep token id.
        pad_token_id (int): Value with which to pad each text so that all texts are the same length.
            Defaults to 0, Hugging Face's BERT pad token id.

    Inputs:
        text (Union[List[str], str]): Input text to transform.
    """

    def __init__(
        self,
        pretrained_tokenizer: str = "bert-base-uncased",
        do_pre_process: bool = True,
        truncate: bool = False,
        pad_to_max_seq_len: bool = False,
        add_end_token: bool = True,
        max_seq_len: int = 25,
        cls_token_id: int = 101,
        sep_token_id: int = 102,
        pad_token_id: int = 0,
    ):
        self.do_pre_process = do_pre_process
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.pad_token_id = pad_token_id
        self.add_end_token = add_end_token

        self.tokenizer = BertTokenizer.from_pretrained(pretrained_tokenizer)
        self.transform = Sequential(
            Truncate(max_seq_len=max_seq_len) if truncate else torch.nn.Identity(),
            ToTensor(padding_value=self.pad_token_id),
            (
                PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id)
                if pad_to_max_seq_len
                else torch.nn.Identity()
            ),
        )

    def pre_process(self, text: str) -> str:
        text = (
            re.sub(
                r"([,.'!?\"()*#:;~])",
                "",
                text,
            )
            .replace("-", " ")
            .replace("/", " ")
        )
        text = text.rstrip(" ")

        return text

    def __call__(self, text: Union[List[str], str]) -> torch.Tensor:
        if self.do_pre_process:
            if isinstance(text, str):
                text = self.pre_process(text)
            else:
                text = [self.pre_process(t) for t in text]
        tokens = self.tokenizer(text)["input_ids"]
        if not self.add_end_token and tokens[-1] == self.sep_token_id:
            tokens = tokens[:-1]
        input_ids = self.transform(tokens)

        return input_ids


def training_image_transform(
    image_size: int = 384,
    scale: Tuple[float, float] = (0.5, 1.0),
    image_interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Tuple[float, float, float] = MEAN,
    std_dev: Tuple[float, float, float] = STD_DEV,
) -> transforms.Compose:
    return transforms.Compose(
        [
            transforms.RandomResizedCrop(
                image_size, scale=scale, interpolation=image_interpolation
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandAugment(2, 7),
            transforms.ToTensor(),
            transforms.Normalize(mean, std_dev),
        ]
    )


def testing_image_transform(
    image_size: int = 384,
    image_interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Tuple[float, float, float] = MEAN,
    std_dev: Tuple[float, float, float] = STD_DEV,
) -> transforms.Compose:
    return transforms.Compose(
        [
            transforms.Resize(
                (image_size, image_size), interpolation=image_interpolation
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean, std_dev),
        ]
    )
multimodal/examples/albef/data/vqa_datamodules.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple
8
+
9
+ import torch
10
+ from data.transforms import (
11
+ ALBEFTextTransform,
12
+ testing_image_transform,
13
+ training_image_transform,
14
+ )
15
+ from data.vqa_dataset import VQADataset
16
+ from pytorch_lightning import LightningDataModule
17
+ from torch import Tensor
18
+ from torch.nn.utils.rnn import pad_sequence
19
+ from torch.utils.data import DataLoader, DistributedSampler
20
+
21
+
22
+ class VQADataModule(LightningDataModule):
23
+ """
24
+ The Data Module for Visual Question Answering task.
25
+
26
+ Args:
27
+ train_files (List[str]): The paths to training json files.
28
+ test_files (List[str]): The paths to testing json files.
29
+ answer_list (str): The path to the answers list.
30
+ vqa_root (str): The path to vqa data directory.
31
+ vg_root (str): The path to vg data directory.
32
+ batch_size (int): The sampling batch size.
33
+ num_workers (int): The number of workers for the distributed mode.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ train_files: List[str],
39
+ test_files: List[str],
40
+ answer_list: str,
41
+ vqa_root: str,
42
+ vg_root: str,
43
+ batch_size: int,
44
+ num_workers: int,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.train_dataset = VQADataset(
48
+ train_files,
49
+ vqa_root,
50
+ vg_root,
51
+ image_transform=training_image_transform(),
52
+ question_transform=ALBEFTextTransform(
53
+ truncate=True, max_seq_len=25, add_end_token=False
54
+ ),
55
+ answer_transform=ALBEFTextTransform(do_pre_process=False),
56
+ split="train",
57
+ )
58
+
59
+ self.test_dataset = VQADataset(
60
+ test_files,
61
+ vqa_root,
62
+ vg_root,
63
+ image_transform=testing_image_transform(),
64
+ question_transform=ALBEFTextTransform(add_end_token=False),
65
+ answer_transform=ALBEFTextTransform(do_pre_process=False),
66
+ split="test",
67
+ answer_list=answer_list,
68
+ )
69
+
70
+ self.batch_size = batch_size
71
+ self.num_workers = num_workers
72
+
73
+ def _get_sampler(
74
+ self,
75
+ dataset: VQADataset,
76
+ shuffle: bool,
77
+ is_distributed: bool,
78
+ num_tasks: int,
79
+ global_rank: int,
80
+ ) -> Optional[DistributedSampler]:
81
+ if not is_distributed:
82
+ return None
83
+
84
+ return DistributedSampler(
85
+ dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
86
+ )
87
+
88
+ def train_dataloader(
89
+ self,
90
+ is_distributed: bool = False,
91
+ num_tasks: int = 0,
92
+ global_rank: int = 0,
93
+ drop_last: bool = True,
94
+ ) -> DataLoader:
95
+ """
96
+ DataLoader Outputs:
97
+ images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
98
+ questions (Tensor): Tensor of shape (B, L) of question inputs.
99
+ question_atts (Tensor): Tensor of shape (B, L) of question attention mask.
100
+ answers (Tensor): Tensor of shape (N, M) of answer inputs.
101
+ N >= B because a vqa sample can have multiple answers.
102
+ answer_atts (Tensor): Tensor of shape (N, M) of answer attention mask.
103
+ weights (Tensor): Tensor of shape (N) of answer weights.
104
+ ans_lengths (List[int]): List of length B and sum N where
105
+ ans_lengths[i] = number of answers for images[i] and questions[i].
106
+ """
107
+ sampler = self._get_sampler(
108
+ dataset=self.train_dataset,
109
+ shuffle=True,
110
+ is_distributed=is_distributed,
111
+ num_tasks=num_tasks,
112
+ global_rank=global_rank,
113
+ )
114
+ shuffle = sampler is None
115
+ return DataLoader(
116
+ self.train_dataset,
117
+ batch_size=self.batch_size,
118
+ num_workers=self.num_workers,
119
+ pin_memory=True,
120
+ sampler=sampler,
121
+ shuffle=shuffle,
122
+ collate_fn=vqa_train_collate_fn,
123
+ drop_last=drop_last,
124
+ )
125
+
126
+ def test_dataloader(
127
+ self,
128
+ is_distributed: bool = False,
129
+ num_tasks: int = 0,
130
+ global_rank: int = 0,
131
+ drop_last=False,
132
+ ) -> DataLoader:
133
+ """
134
+ DataLoader Outputs:
135
+ images (Tensor): Tensor of shape (B, C, W, H) of image inputs.
136
+ questions (Tensor): Tensor of shape (B, L) of question inputs.
137
+ question_atts (Tensor): Tensor of shape (B, L) of question attention mask.
138
+ question_ids (List): List of length B of question ids.
139
+ """
140
+ sampler = self._get_sampler(
141
+ dataset=self.test_dataset,
142
+ shuffle=False,
143
+ is_distributed=is_distributed,
144
+ num_tasks=num_tasks,
145
+ global_rank=global_rank,
146
+ )
147
+ return DataLoader(
148
+ self.test_dataset,
149
+ batch_size=self.batch_size,
150
+ num_workers=self.num_workers,
151
+ pin_memory=True,
152
+ sampler=sampler,
153
+ shuffle=False,
154
+ collate_fn=vqa_test_collate_fn,
155
+ drop_last=drop_last,
156
+ )
157
+
158
+
159
+ def vqa_train_collate_fn(
160
+ batch: List[Tuple[Tensor, Tensor, List[Tensor], List[float]]],
161
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[int]]:
162
+ image_list = []
163
+ question_list = []
164
+ answer_list = []
165
+ weight_list = []
166
+ ans_lengths = []
167
+ for image, question, answer, weights in batch:
168
+ image_list.append(image)
169
+ question_list.append(question)
170
+ answer_list += answer
171
+ weight_list += weights
172
+ ans_lengths.append(len(answer))
173
+ images = torch.stack(image_list, dim=0)
174
+ questions = pad_sequence(question_list, batch_first=True)
175
+ question_atts = (questions != 0).type(torch.long)
176
+ answers = pad_sequence(answer_list, batch_first=True)
177
+ answer_atts = (answers != 0).type(torch.long)
178
+ weights = torch.Tensor(weight_list)
179
+ return (
180
+ images,
181
+ questions,
182
+ question_atts,
183
+ answers,
184
+ answer_atts,
185
+ weights,
186
+ ans_lengths,
187
+ )
188
+
189
+
190
+ def vqa_test_collate_fn(
191
+ batch: List[Tuple[Tensor, Tensor, int]],
192
+ ) -> Tuple[Tensor, Tensor, Tensor, List[int]]:
193
+ image_list, question_list, question_ids = [], [], []
194
+ for image, question, question_id in batch:
195
+ image_list.append(image)
196
+ question_list.append(question)
197
+ question_ids.append(question_id)
198
+ images = torch.stack(image_list, dim=0)
199
+ questions = pad_sequence(question_list, batch_first=True)
200
+ question_atts = (questions != 0).type(torch.long)
201
+ return (
202
+ images,
203
+ questions,
204
+ question_atts,
205
+ question_ids,
206
+ )
multimodal/examples/albef/data/vqa_dataset.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ from typing import Callable, List, Tuple, Union
10
+
11
+ import torch
12
+
13
+ from PIL import Image
14
+ from torch import Tensor
15
+ from torch.utils.data import Dataset
16
+
17
+
18
+ class VQADataset(Dataset):
19
+ """
20
+ Create the dataset for VQA task.
21
+
22
+ Args:
23
+ ann_file (List[str]): The paths to annotation json files.
24
+ vqa_root (str): The path to vqa data directory.
25
+ vg_root (str): The path to vg data directory.
26
+ image_transform (Callable[[Image.Image], Tensor]): image data transform.
27
+ question_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for questions.
28
+ answer_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for answers.
29
+ split (str): Indicates train or test. Default is train.
30
+ answer_list (str): The path to the answers list. Required for test split.
31
+
32
+ Dataset Outputs:
33
+ if split is train:
34
+ image (Tensor): Transformed image input tensor of shape (C, W, H).
35
+ question (Tensor): Transformed question token input ids.
36
+ answers (List[Tensor]): List of transformed answers token input ids.
37
+ answer_weights (List[float]): List of answer weights.
38
+ answer_weights[i] is proportional to the number of occurences of answers[i]
39
+ if split is test:
40
+ image (Tensor): Transformed image input tensor of shape (C, W, H).
41
+ question (Tensor): Transformed text token input ids.
42
+ question_id (int): The question sample id.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ ann_file: List[str],
48
+ vqa_root: str,
49
+ vg_root: str,
50
+ image_transform: Callable[[Image.Image], Tensor],
51
+ question_transform: Callable[[Union[List[str], str]], Tensor],
52
+ answer_transform: Callable[[Union[List[str], str]], Tensor],
53
+ split: str = "train",
54
+ answer_list: str = None,
55
+ ) -> None:
56
+ self.ann = []
57
+ for f in ann_file:
58
+ self.ann += json.load(open(f, "r"))
59
+
60
+ self.vqa_root = vqa_root
61
+ self.vg_root = vg_root
62
+ self.image_transform = image_transform
63
+ self.question_transform = question_transform
64
+ self.answer_transform = answer_transform
65
+ self.split = split
66
+
67
+ if split == "test":
68
+ self.answer_list = json.load(open(answer_list, "r"))
69
+ self.answer_input_ids = self.answer_transform(self.answer_list)
70
+ self.answer_attention_mask = (self.answer_input_ids != 0).type(torch.long)
71
+
72
+ def __len__(self) -> int:
73
+ return len(self.ann)
74
+
75
+ def __getitem__(
76
+ self, index: int
77
+ ) -> Union[
78
+ Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, List[Tensor], List[float]]
79
+ ]:
80
+ ann = self.ann[index]
81
+
82
+ image_root = self.vqa_root if ann["dataset"] == "vqa" else self.vg_root
83
+ image_path = os.path.join(image_root, ann["image"])
84
+ image = Image.open(image_path).convert("RGB")
85
+ image = self.image_transform(image)
86
+ question = self.question_transform(ann["question"])
87
+
88
+ if self.split == "test":
89
+ return image, question, ann["question_id"]
90
+
91
+ elif self.split == "train":
92
+ if ann["dataset"] == "vqa":
93
+ # Each VQA sample question has a list of answers (with potential repeats)
94
+ # answer_weight[answer] = count(answer) / len(answers for the question)
95
+ answer_weights = {}
96
+ for answer in ann["answer"]:
97
+ if answer in answer_weights.keys():
98
+ answer_weights[answer] += 1 / len(ann["answer"])
99
+ else:
100
+ answer_weights[answer] = 1 / len(ann["answer"])
101
+
102
+ answers = list(answer_weights.keys())
103
+ answer_weights = list(answer_weights.values())
104
+
105
+ elif ann["dataset"] == "vg":
106
+ # A VG sample question has one answer so assign it a constant weight (0.5)
107
+ answers = [ann["answer"]]
108
+ answer_weights = [0.5]
109
+
110
+ answers = list(self.answer_transform(answers))
111
+
112
+ return image, question, answers, answer_weights
113
+
114
+ else:
115
+ raise ValueError("dataset split should be train or test")
multimodal/examples/common/data/__init__.py ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .multidata import *  # noqa F401
multimodal/examples/common/data/multidata.py ADDED
@@ -0,0 +1,194 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import random
8
+ import warnings
9
+ from functools import partial
10
+ from typing import Callable, List, Optional
11
+
12
+ import torch
13
+ from pytorch_lightning import LightningDataModule
14
+
15
+
16
+ class MultiDataLoader:
17
+ # NOTE: Please check MMF's MultiDataLoader if you want to support
18
+ # epoch based sampling funcs.
19
+ def __init__(
20
+ self,
21
+ loaders: List[torch.utils.data.DataLoader],
22
+ sampling_func: Optional[Callable] = None,
23
+ ):
24
+ """MultiDataLoader takes in a list of dataloaders and a sampling function
25
+ and cycles between these dataloaders after each batch based on the index
26
+ provided by the sampling function passed. Useful for doing multi-tasking
27
+ over multiple datasets
28
+
29
+ Args:
30
+ loaders (List[torch.utils.data.DataLoader]): List of dataloaders on
31
+ which the multitasking has to be done.
32
+
33
+ sampling_func (Optional[Callable], optional): Function which will return
34
+ the next index to be selected. Defaults to equally weight sampling.
35
+ """
36
+ if loaders is None or len(loaders) == 0:
37
+ warnings.warn(
38
+ "Empty loaders passed into MultiDataLoader. This can have "
39
+ "unintended consequences."
40
+ )
41
+
42
+ if sampling_func is None:
43
+ sampling_func = partial(random.choice, range(len(loaders)))
44
+
45
+ self.sampling_func = sampling_func
46
+ self.loaders = loaders
47
+ self.num_datasets = len(self.loaders)
48
+ self.iterators = [None for _ in loaders]
49
+ self.current_index = 0
50
+ self.set_samplers()
51
+
52
+ def set_samplers(self):
53
+ self.samplers: List[torch.utils.data.Sampler] = []
54
+ for loader in self.loaders:
55
+ if hasattr(loader, "sampler"):
56
+ self.samplers.append(loader.sampler)
57
+
58
+ def __iter__(self):
59
+ self.iterators = []
60
+
61
+ for loader in self.loaders:
62
+ self.iterators.append(iter(loader))
63
+
64
+ self.change_dataloader()
65
+
66
+ return self
67
+
68
+ def __next__(self):
69
+ """
70
+ Calculation of next batch is performed using following logic.
71
+
72
+ Current chosen iterator is set in the change_dataloader function
73
+ based on the `sampling_func` function passed to `__init__` of the
74
+ dataloader which is called to get the index of next selected dataloader.
75
+
76
+ If we get the next batch from iterator without any StopIteration exception,
77
+ we return it as it is.
78
+
79
+ Epochs don't make sense in case of using `sampling_func` unless you add
80
+ extra logic to support epoch-based sampling functions. MMF does this in
81
+ a different way, so take a look at IterationStrategies there to understand
82
+ how this can be possibly done.
83
+
84
+ Think of a case of random (equal) proportional sampling for dataset x and y
85
+ where x is half the size of y. When x will complete its 2 epochs, y will
86
+ have only 1 epoch completed. **So please don't use max_epochs or epoch
87
+ based training in this case as it won't be honored**. If an iterator is
88
+ finished, we just reignite it in this case and finished iterators
89
+ variable isn't used. This means that this case will never reach the
90
+ __iter__ function ever again.
91
+
92
+
93
+ Returns:
94
+ Dict: Contains two keys, one "batch" containing the batch from current
95
+ selected dataloader and "datamodule_index" which is index of
96
+ currently selected dataloader.
97
+ """
98
+ self.change_dataloader()
99
+ try:
100
+ next_batch = next(self.current_iterator)
101
+ except StopIteration:
102
+ iterator = iter(self.loaders[self.current_index])
103
+ self.iterators[self.current_index] = iterator
104
+ self.current_iterator = iterator
105
+ next_batch = next(self.current_iterator)
106
+
107
+ return {"batch": next_batch, "datamodule_index": self.current_index}
108
+
109
+ def change_dataloader(self):
110
+ choice = 0
111
+
112
+ if self.num_datasets <= 1:
113
+ self.current_index = choice
114
+ self.current_iterator = self.iterators[self.current_index]
115
+ return
116
+
117
+ choice = [self.sampling_func()]
118
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
119
+ # This broadcast is probably unnecessary with lightning if everything
120
+ # is already properly seeded. But,to be on safe side, we can still
121
+ # do this.
122
+ # There are also some smarter ways to do this to avoid any broadcasting
123
+ # by basically having a fixed generator with a fixed seed which will
124
+ # always work deterministically.
125
+ # TODO: Check if not doing this provides any speed benefits.
126
+ torch.distributed.broadcast_object_list(choice, 0)
127
+
128
+ self.current_index = choice[0]
129
+ self.current_iterator = self.iterators[self.current_index]
130
+
131
+ def set_epoch(self, epoch: int):
132
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
133
+ for sampler in self.samplers:
134
+ if sampler is not None and hasattr(sampler, "set_epoch"):
135
+ sampler.set_epoch(epoch)
136
+
137
+
138
+ class MultiDataModule(LightningDataModule):
139
+ """MultiDataModule is just an abstraction over MultiDataLoader
140
+ that will allow us to integrate it with Lightning.
141
+ """
142
+
143
+ # NOTE: Add rest of the functions that should be called on child datamodules
144
+ # as required
145
+ def __init__(
146
+ self,
147
+ datamodules: List[LightningDataModule],
148
+ sampling_func: Optional[Callable] = None,
149
+ ):
150
+ super().__init__()
151
+ self.datamodules = datamodules
152
+ self.sampling_func = sampling_func
153
+ self.current_datamodule_idx = 0
154
+
155
+ def setup(self, stage=None):
156
+ for datamodule in self.datamodules:
157
+ datamodule.setup(stage)
158
+
159
+ def prepare_data(self):
160
+ for datamodule in self.datamodules:
161
+ datamodule.prepare_data()
162
+
163
+ def train_dataloader(self) -> MultiDataLoader:
164
+ # TODO: Fix assign inconsistency
165
+ return self._build_multi_dataloader("train")
166
+
167
+ def val_dataloader(self) -> MultiDataLoader:
168
+ return self._build_multi_dataloader("val")
169
+
170
+ def test_dataloader(self) -> MultiDataLoader:
171
+ return self._build_multi_dataloader("test")
172
+
173
+ def _build_multi_dataloader(self, split="train"):
174
+ dataloaders = []
175
+ for datamodule in self.datamodules:
176
+ dataloaders.append(getattr(datamodule, f"{split}_dataloader")())
177
+
178
+ return MultiDataLoader(dataloaders, self.sampling_func)
179
+
180
+ def on_before_batch_transfer(self, batch, *args):
181
+ batch, index = batch["batch"], batch["datamodule_index"]
182
+ self.current_datamodule_idx = index
183
+ return self.datamodules[self.current_datamodule_idx].on_before_batch_transfer(
184
+ batch, *args
185
+ )
186
+
187
+ def on_after_batch_transfer(self, batch, *args):
188
+ return self.datamodules[self.current_datamodule_idx].on_after_batch_transfer(
189
+ batch, *args
190
+ )
191
+
192
+ def teardown(self, stage):
193
+ for datamodule in self.datamodules:
194
+ datamodule.teardown(stage)
multimodal/examples/flava/callbacks/__init__.py ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .multimodal_eval import *  # noqa F401
multimodal/examples/flava/callbacks/multimodal_eval.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ import torch
10
+ from flava.data import default_text_transform, VL_MAX_LENGTH_DEFAULT
11
+ from flava.data.imagenet_zeroshot_data import (
12
+ imagenet_classnames,
13
+ openai_imagenet_template,
14
+ )
15
+ from pytorch_lightning import Callback, LightningDataModule
16
+ from pytorch_lightning.utilities import rank_zero_only
17
+ from tqdm import tqdm
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ def _zero_shot_classifier(model, device, text_transform, *args, **kwargs):
24
+ zeroshot_weights = []
25
+ for classname in tqdm(imagenet_classnames):
26
+ texts = text_transform(
27
+ [template(classname) for template in openai_imagenet_template]
28
+ )["input_ids"]
29
+ texts = texts.to(device)
30
+ class_embeddings = model.encode_text(texts)
31
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
32
+ class_embedding = class_embeddings.mean(dim=0)
33
+ class_embedding /= class_embedding.norm()
34
+ zeroshot_weights.append(class_embedding)
35
+
36
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
37
+ return zeroshot_weights
38
+
39
+
40
+ def _accuracy(output, target, topk=(1,)):
41
+ pred = output.topk(max(topk), 1, True, True)[1].t()
42
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
43
+ return [
44
+ float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
45
+ for k in topk
46
+ ]
47
+
48
+
49
+ @rank_zero_only
50
+ def run_imagenet_zero_shot(model, dataloader, device, text_transform, *args, **kwargs):
51
+ logger.info("Starting ImageNet Zero-Shot Eval")
52
+ logger.info("Building classifier")
53
+ classifier = _zero_shot_classifier(model, device, text_transform)
54
+ logger.info("Classifier built")
55
+ top1, top5, n = 0.0, 0.0, 0.0
56
+ for sample in tqdm(dataloader):
57
+ images = sample["image"]
58
+ target = sample["label"]
59
+ images = images.to(device)
60
+ target = target.to(device)
61
+
62
+ # predict
63
+ # if hasattr(model, "module"):
64
+ # image_features = model.module.encode_image({"image": images})
65
+ # else:
66
+ image_features = model.encode_image(images)
67
+ image_features /= image_features.norm(dim=-1, keepdim=True)
68
+ logits = 100.0 * image_features @ classifier
69
+
70
+ # measure accuracy
71
+ acc1, acc5 = _accuracy(logits, target, topk=(1, 5))
72
+ top1 += acc1
73
+ top5 += acc5
74
+ n += images.size(0)
75
+
76
+ top1 = top1 / n
77
+ top5 = top5 / n
78
+ results = {}
79
+ results["imagenet-zeroshot-val-top1"] = top1
80
+ results["imagenet-zeroshot-val-top5"] = top5
81
+ return results
82
+
83
+
84
+ class MultimodalEvalCallback(Callback):
85
+ def __init__(self, imagenet_datamodule: LightningDataModule, *args, **kwargs):
86
+ super().__init__()
87
+ self.imagenet_val_dataloader = imagenet_datamodule.val_dataloader()
88
+ self.text_transform = default_text_transform(
89
+ max_text_length=VL_MAX_LENGTH_DEFAULT
90
+ )
91
+
92
+ @torch.no_grad()
93
+ def on_validation_start(self, trainer, pl_module, **kwargs) -> None:
94
+ metrics = run_imagenet_zero_shot(
95
+ pl_module.model,
96
+ self.imagenet_val_dataloader,
97
+ pl_module.device,
98
+ self.text_transform,
99
+ )
100
+ if metrics is not None:
101
+ for key in metrics:
102
+ self.log(
103
+ f"val/{key}",
104
+ metrics[key],
105
+ prog_bar=True,
106
+ logger=True,
107
+ rank_zero_only=True,
108
+ )
multimodal/examples/flava/configs/finetuning/qnli.yaml ADDED
@@ -0,0 +1,48 @@
# Note that in original FLAVA paper, only Logistic Regression numbers were provided for image datasets.
_target_: flava.definitions.FLAVAArguments
training:
  _target_: flava.definitions.TrainingArguments
  lightning:
    max_steps: 33112
    gpus: 1
    val_check_interval: 1000
    num_sanity_val_steps: 0
    strategy: ddp
  lightning_checkpoint:
    dirpath: "."
    filename: flava-{epoch:02d}-{step}
    save_last: true
    every_n_train_steps: 1000
    save_on_train_epoch_end: true
    verbose: true
    monitor: validation/accuracy/classification
    mode: max
  lightning_load_from_checkpoint: null
  seed: -1
  batch_size: 32
  num_workers: 4
  learning_rate: 1e-5
  adam_eps: 1e-6
  adam_weight_decay: 0.1
  adam_betas:
    - 0.9
    - 0.98
  warmup_steps: 1986


datasets:
  _target_: flava.definitions.TrainingDatasetsInfo
  selected:
    - text
  num_classes: 2
  text:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: glue
        subset: qnli
        rename_columns:
          - ["question", "sentence1"]
          - ["sentence", "sentence2"]
    datamodule_extra_kwargs:
      text_columns: ["sentence1", "sentence2"]
multimodal/examples/flava/configs/finetuning/rendered_sst2.yaml ADDED
@@ -0,0 +1,37 @@
# Note that in original FLAVA paper, only Logistic Regression numbers were provided for image datasets.
_target_: flava.definitions.FLAVAArguments
training:
  _target_: flava.definitions.TrainingArguments
  lightning:
    max_steps: 20935
    gpus: -1
    val_check_interval: 100
    num_sanity_val_steps: 0
    strategy: ddp
  lightning_checkpoint:
    dirpath: "."
    filename: flava-{epoch:02d}-{step}
    save_last: true
    every_n_train_steps: 1000
    save_on_train_epoch_end: true
    verbose: true
  lightning_load_from_checkpoint: null
  seed: -1
  batch_size: 32
  num_workers: 4
  learning_rate: 1e-5
  adam_eps: 1e-8
  adam_weight_decay: 1e-2
  warmup_steps: 1256


datasets:
  _target_: flava.definitions.TrainingDatasetsInfo
  selected:
    - image
  num_classes: 2
  image:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.TorchVisionDatasetInfo
        key: RenderedSST2
multimodal/examples/flava/configs/pretraining/debug.yaml ADDED
@@ -0,0 +1,61 @@
_target_: flava.definitions.FLAVAArguments
training:
  _target_: flava.definitions.TrainingArguments
  lightning:
    max_steps: 450000
    gpus: -1
    val_check_interval: 10000
    num_sanity_val_steps: 0
    strategy: ddp
  lightning_checkpoint:
    dirpath: "."
    filename: flava-{epoch:02d}-{step}
    save_last: true
    every_n_train_steps: 1000
    save_on_train_epoch_end: true
    verbose: true
  lightning_load_from_checkpoint: null
  seed: -1
  batch_size: 8
  num_workers: 4
  learning_rate: 2e-4
  adam_eps: 1e-8
  adam_weight_decay: 1e-2
  warmup_steps: 2000

datasets:
  _target_: flava.definitions.TrainingDatasetsInfo
  selected:
    - image
    - vl
    - text
  image:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: imagenet-1k
        subset: default
  text:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: wikitext
        subset: wikitext-103-raw-v1
    datamodule_extra_kwargs:
      text_columns: ["text"]
  vl:
    _target_: flava.definitions.TrainingSingleDatasetInfo
    train:
      - _target_: flava.definitions.HFDatasetInfo
        key: red_caps
        subset: jellyfish
        rename_columns:
          - ["caption", "text"]
    val:
      - _target_: flava.definitions.HFDatasetInfo
        key: red_caps
        subset: jellyfish
        rename_columns:
          - ["caption", "text"]
        split_key_mapping:
          validation: train
multimodal/examples/flava/data/__init__.py ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .transforms import *  # noqa F401
from .utils import *  # noqa F401
from .imagenet_zeroshot_data import *  # noqa F401
from .datamodules import *  # noqa F401
multimodal/examples/flava/data/datamodules.py ADDED
@@ -0,0 +1,529 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from functools import partial
9
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torchvision
13
+ from flava.definitions import HFDatasetInfo, TorchVisionDatasetInfo
14
+ from pytorch_lightning import LightningDataModule
15
+ from transformers import (
16
+ BertTokenizer,
17
+ DataCollatorForLanguageModeling,
18
+ DataCollatorForWholeWordMask,
19
+ DefaultDataCollator,
20
+ TRANSFORMERS_CACHE,
21
+ )
22
+ from transformers.data.data_collator import torch_default_data_collator
23
+
24
+ from .transforms import (
25
+ default_image_pretraining_transforms,
26
+ default_text_transform,
27
+ default_torchvision_transforms,
28
+ encode_text_batch,
29
+ pad_batch,
30
+ TEXT_DEFAULT_TOKENIZER,
31
+ TEXT_WHOLE_WORD_MASK_TOKENIZER,
32
+ VL_MAX_LENGTH_DEFAULT,
33
+ VLTransform,
34
+ )
35
+ from .utils import build_datasets_from_info, fetch_images
36
+
37
+
38
+ def transform_image(transform, sample):
39
+ sample.update(transform(sample["image"]))
40
+ return sample
41
+
42
+
43
+ class DataCollatorForWholeWordMaskRetainingBatch(DataCollatorForWholeWordMask):
44
+ def torch_call(
45
+ self, examples: List[Union[List[int], Any, Dict[str, Any]]]
46
+ ) -> Dict[str, Any]:
47
+ masked_batch = super().torch_call(examples)
48
+ examples = torch_default_data_collator(examples)
49
+ examples["input_ids"] = masked_batch["input_ids"]
50
+ examples["labels"] = masked_batch["labels"]
51
+ return examples
52
+
53
+
54
+ class ImageDataModule(LightningDataModule):
55
+ def __init__(
56
+ self,
57
+ train_infos: List[HFDatasetInfo],
58
+ val_infos: Optional[List[HFDatasetInfo]] = None,
59
+ transforms: Optional[Tuple[Callable, Callable]] = None,
60
+ batch_size: int = 32,
61
+ num_workers: int = 4,
62
+ allow_uneven_batches: bool = False,
63
+ **kwargs: Any,
64
+ ):
65
+ super().__init__()
66
+ self.train_dataset_infos = train_infos
67
+ self.val_dataset_infos = val_infos
68
+ if self.val_dataset_infos is None:
69
+ self.val_dataset_infos = train_infos
70
+
71
+ self.batch_size = batch_size
72
+ self.num_workers = num_workers
73
+ self.allow_uneven_batches = allow_uneven_batches
74
+
75
+ if transforms is None:
76
+ transforms = default_image_pretraining_transforms()
77
+
78
+ self.train_transform, self.test_transform = transforms
79
+
80
+ def setup(self, stage=None):
81
+ train_transform = partial(transform_image, self.train_transform)
82
+ val_transform = partial(transform_image, self.test_transform)
83
+
84
+ self.train_dataset = build_datasets_from_info(
85
+ self.train_dataset_infos, split="train"
86
+ )
87
+ self.train_dataset.set_transform(train_transform)
88
+ self.val_dataset = build_datasets_from_info(
89
+ self.val_dataset_infos, split="validation"
90
+ )
91
+ self.val_dataset.set_transform(val_transform)
92
+
93
+ def train_dataloader(self):
94
+ return torch.utils.data.DataLoader(
95
+ self.train_dataset,
96
+ batch_size=self.batch_size,
97
+ num_workers=self.num_workers,
98
+ sampler=None,
99
+ shuffle=True,
100
+ # uneven batches can cause distributed issues,
101
+ # drop last batch to prevent those.
102
+ # ideally, we don't need to drop these for unimodal cases
103
+ # but just to be safe
104
+ drop_last=True,
105
+ )
106
+
107
+ def val_dataloader(self):
108
+ return torch.utils.data.DataLoader(
109
+ self.val_dataset,
110
+ batch_size=self.batch_size,
111
+ num_workers=self.num_workers,
112
+ sampler=None,
113
+ shuffle=False,
114
+ # uneven batches can cause distributed issues,
115
+ # drop last batch to prevent those.
116
+ # ideally, we don't need to drop these for unimodal cases
117
+ # but just to be safe
118
+ drop_last=True,
119
+ )
120
+
121
+ def test_dataloader(self):
122
+ return self.val_dataloader()
123
+
124
+ def on_before_batch_transfer(self, batch, *args):
125
+ if batch["label"].size(0) < self.batch_size and not self.allow_uneven_batches:
126
+ batch = pad_batch(batch, self.batch_size)
127
+ return batch
128
+
129
+
130
+ class TextDataModule(LightningDataModule):
131
+ def __init__(
132
+ self,
133
+ train_infos: List[HFDatasetInfo],
134
+ text_columns: List[str],
135
+ val_infos: Optional[List[HFDatasetInfo]] = None,
136
+ tokenizer: Optional[Callable] = None,
137
+ max_length: int = 512,
138
+ batch_size: int = 32,
139
+ num_workers: int = 4,
140
+ allow_uneven_batches: bool = False,
141
+ **kwargs: Any,
142
+ ):
143
+ super().__init__()
144
+ self.train_dataset_infos = train_infos
145
+ self.text_columns = text_columns
146
+ self.val_dataset_infos = val_infos
147
+ if self.val_dataset_infos is None:
148
+ self.val_dataset_infos = train_infos
149
+ self.tokenizer = tokenizer
150
+ self.max_length = max_length
151
+ self.batch_size = batch_size
152
+ self.num_workers = num_workers
153
+ self.allow_uneven_batches = allow_uneven_batches
154
+
155
+ def setup(self, stage=None):
156
+ if self.tokenizer is None:
157
+ self.tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
158
+ transform = partial(
159
+ encode_text_batch,
160
+ tokenizer=self.tokenizer,
161
+ padding="max_length",
162
+ max_length=self.max_length,
163
+ truncation=True,
164
+ return_tensors="pt",
165
+ return_special_tokens_mask=True,
166
+ text_columns=self.text_columns,
167
+ return_batch=True,
168
+ )
169
+ self.train_dataset = build_datasets_from_info(
170
+ self.train_dataset_infos, split="train"
171
+ )
172
+ self.train_dataset.set_transform(transform)
173
+ self.val_dataset = build_datasets_from_info(
174
+ self.val_dataset_infos, split="validation"
175
+ )
176
+ self.val_dataset.set_transform(transform)
177
+
178
+ def train_dataloader(self):
179
+ return self._build_dataloader(self.train_dataset)
180
+
181
+ def val_dataloader(self):
182
+ return self._build_dataloader(self.val_dataset, shuffle=False)
183
+
184
+ def _build_dataloader(self, dataset, drop_last=False, shuffle=True):
185
+ return torch.utils.data.DataLoader(
186
+ dataset,
187
+ batch_size=self.batch_size,
188
+ num_workers=self.num_workers,
189
+ sampler=None,
190
+ shuffle=shuffle,
191
+ collate_fn=self._build_collator(),
192
+ drop_last=drop_last,
193
+ )
194
+
195
+ def _build_collator(self):
196
+ return DefaultDataCollator()
197
+
198
+ def on_before_batch_transfer(self, batch, *args):
199
+ batch.pop("token_type_ids", None)
200
+ mask = batch.pop("attention_mask", None)
201
+ if mask.size(0) < self.batch_size and not self.allow_uneven_batches:
202
+ batch = pad_batch(batch, self.batch_size)
203
+ return batch
204
+
205
+ def on_after_batch_transfer(self, batch, *args):
206
+ batch["text"] = batch.pop("input_ids")
207
+ return batch
208
+
209
+
210
+ class MLMDataModule(TextDataModule):
211
+ def __init__(
212
+ self,
213
+ train_infos: List[HFDatasetInfo],
214
+ text_columns: List[str],
215
+ val_infos: Optional[List[HFDatasetInfo]] = None,
216
+ mlm_probability: float = 0.15,
217
+ ignore_index: int = -1,
218
+ **kwargs: Any,
219
+ ):
220
+ super().__init__(train_infos, text_columns, val_infos, **kwargs)
221
+ self.mlm_probability = mlm_probability
222
+ self.ignore_index = ignore_index
223
+
224
+ def setup(self, stage=None):
225
+ if self.tokenizer is None:
226
+ self.tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
227
+ transform = partial(
228
+ encode_text_batch,
229
+ tokenizer=self.tokenizer,
230
+ padding="max_length",
231
+ max_length=self.max_length,
232
+ truncation=True,
233
+ return_tensors="pt",
234
+ return_special_tokens_mask=True,
235
+ text_columns=self.text_columns,
236
+ return_batch=False,
237
+ )
238
+ self.train_dataset = build_datasets_from_info(
239
+ self.train_dataset_infos, split="train"
240
+ )
241
+ self.train_dataset.set_transform(transform)
242
+ self.val_dataset = build_datasets_from_info(
243
+ self.val_dataset_infos, split="validation"
244
+ )
245
+ self.val_dataset.set_transform(transform)
246
+
247
+ def _build_dataloader(self, dataset, drop_last=True, shuffle=True):
248
+ # uneven batches can cause distributed issues,
249
+ # drop last batch to prevent those.
250
+ # ideally, we don't need to drop these for unimodal cases
251
+ # but just to be safe
252
+ return super()._build_dataloader(dataset, drop_last=drop_last, shuffle=shuffle)
253
+
254
+ def _build_collator(self):
255
+ return DataCollatorForLanguageModeling(
256
+ self.tokenizer, mlm_probability=self.mlm_probability
257
+ )
258
+
259
+ def on_after_batch_transfer(self, batch, *args):
260
+ batch["text_masked"] = batch.pop("input_ids")
261
+ batch["mlm_labels"] = batch.pop("labels")
262
+ batch["mlm_labels"][batch["mlm_labels"] == -100] = self.ignore_index
263
+ return batch
264
+
265
+
266
+ class VLDataModule(LightningDataModule):
267
+ def __init__(
268
+ self,
269
+ train_infos: List[HFDatasetInfo],
270
+ val_infos: List[HFDatasetInfo],
271
+ text_transform: Optional[Callable] = None,
272
+ image_transforms: Optional[Tuple[Callable, Callable]] = None,
273
+ mlm_probability: float = 0.15,
274
+ batch_size: int = 32,
275
+ num_workers: int = 4,
276
+ finetuning: bool = False,
277
+ ignore_index: int = -1,
278
+ itm_probability: float = 0.1,
279
+ allow_uneven_batches: bool = False,
280
+ fetch_num_threads: int = 4,
281
+ fetch_retries: int = 0,
282
+ fetch_sleep_timer: int = 0,
283
+ fetch_timeout: Optional[float] = None,
284
+ fetch_batch_size: int = 50,
285
+ **kwargs,
286
+ ):
287
+ super().__init__()
288
+
289
+ self.train_dataset_infos = train_infos
290
+ self.val_dataset_infos = val_infos
291
+ if self.val_dataset_infos is None:
292
+ self.val_dataset_infos = train_infos
293
+ if image_transforms is None:
294
+ if not finetuning:
295
+ image_transforms = default_image_pretraining_transforms()
296
+ else:
297
+ image_transforms = default_torchvision_transforms(use_dict=True)
298
+
299
+ self.train_image_transform, self.test_image_transform = image_transforms
300
+ self.text_transform = text_transform
301
+ self.mlm_probability = mlm_probability
302
+ self.batch_size = batch_size
303
+ self.num_workers = num_workers
304
+ self.ignore_index = ignore_index
305
+ self.itm_probability = itm_probability
306
+ self.allow_uneven_batches = allow_uneven_batches
307
+ self.fetch_num_threads = fetch_num_threads
308
+ self.fetch_retries = fetch_retries
309
+ self.fetch_sleep_timer = fetch_sleep_timer
310
+ self.fetch_timeout = fetch_timeout
311
+ self.fetch_batch_size = fetch_batch_size
312
+
313
+ def setup(self, stage=None):
314
+ if self.text_transform is None:
315
+ # TODO Update to use whole word mask vocab
316
+ text_tokenizer = BertTokenizer.from_pretrained(
317
+ TEXT_WHOLE_WORD_MASK_TOKENIZER
318
+ )
319
+ self.text_transform = default_text_transform(
320
+ text_tokenizer, max_text_length=VL_MAX_LENGTH_DEFAULT
321
+ )
322
+ self.text_tokenizer = self.text_transform.keywords["tokenizer"]
323
+ train_vl_transform = VLTransform(
324
+ self.train_image_transform, self.text_transform
325
+ )
326
+ val_vl_transform = VLTransform(self.test_image_transform, self.text_transform)
327
+
328
+ train_dataset = build_datasets_from_info(
329
+ self.train_dataset_infos, split="train"
330
+ )
331
+ train_dataset = train_dataset.map(
332
+ fetch_images,
333
+ batched=True,
334
+ batch_size=self.fetch_batch_size,
335
+ fn_kwargs={
336
+ "num_threads": self.fetch_num_threads,
337
+ "timeout": self.fetch_timeout,
338
+ "retries": self.fetch_retries,
339
+ "sleep_timer": self.fetch_sleep_timer,
340
+ },
341
+ )
342
+ train_dataset = train_dataset.filter(
343
+ lambda example: example["image"] is not None
344
+ )
345
+ self.train_dataset = train_dataset
346
+ self.train_dataset.set_transform(
347
+ partial(
348
+ train_vl_transform,
349
+ dataset=train_dataset.filter(lambda example: True),
350
+ itm_probability=self.itm_probability,
351
+ )
352
+ )
353
+
354
+ val_dataset = build_datasets_from_info(
355
+ self.val_dataset_infos, split="validation"
356
+ )
357
+
358
+ val_dataset = val_dataset.map(
359
+ fetch_images,
360
+ batched=True,
361
+ batch_size=self.fetch_batch_size,
362
+ fn_kwargs={
363
+ "num_threads": self.fetch_num_threads,
364
+ "timeout": self.fetch_timeout,
365
+ "retries": self.fetch_retries,
366
+ "sleep_timer": self.fetch_sleep_timer,
367
+ },
368
+ )
369
+ val_dataset = val_dataset.filter(lambda example: example["image"] is not None)
370
+ self.val_dataset = val_dataset
371
+ self.val_dataset.set_transform(
372
+ partial(
373
+ val_vl_transform,
374
+ dataset=self.val_dataset.filter(
375
+ lambda example: True
376
+ ), # Pass a copy to transform
377
+ itm_probability=self.itm_probability,
378
+ )
379
+ )
380
+
381
+ def train_dataloader(self):
382
+ return torch.utils.data.DataLoader(
383
+ self.train_dataset,
384
+ batch_size=self.batch_size,
385
+ num_workers=self.num_workers,
386
+ sampler=None,
387
+ shuffle=True,
388
+ collate_fn=self._build_collator(),
389
+ # uneven batches can cause distributed issues,
390
+ # drop last batch to prevent those.
391
+ drop_last=True,
392
+ )
393
+
394
+ def val_dataloader(self):
395
+ return torch.utils.data.DataLoader(
396
+ self.val_dataset,
397
+ batch_size=self.batch_size,
398
+ num_workers=self.num_workers,
399
+ sampler=None,
400
+ shuffle=False,
401
+ collate_fn=self._build_collator(),
402
+ # uneven batches can cause distributed issues,
403
+ # drop last batch to prevent those.
404
+ drop_last=True,
405
+ )
406
+
407
+ def _build_collator(self):
408
+ return DataCollatorForWholeWordMaskRetainingBatch(
409
+ self.text_tokenizer, mlm_probability=self.mlm_probability
410
+ )
411
+
412
+ def on_before_batch_transfer(self, batch, *args):
413
+ batch.pop("token_type_ids", None)
414
+ mask = batch.pop("attention_mask", None)
415
+ if (
416
+ mask is not None
417
+ and mask.size(0) < self.batch_size
418
+ and not self.allow_uneven_batches
419
+ ):
420
+ batch = pad_batch(batch, self.batch_size)
421
+ return batch
422
+
423
+ def on_after_batch_transfer(self, batch, *args):
424
+ text_masked = batch.pop("input_ids")
425
+ mlm_labels = batch.pop("labels", None)
426
+ mlm_labels[mlm_labels == -100] = self.ignore_index
427
+ text = text_masked.detach().clone()
428
+ text[mlm_labels != -1] = mlm_labels[mlm_labels != -1]
429
+ batch.update(
430
+ {"mlm_labels": mlm_labels, "text": text, "text_masked": text_masked}
431
+ )
432
+ return batch
433
+
434
+
435
+ class TorchVisionDataModule(LightningDataModule):
436
+ def __init__(
437
+ self,
438
+ train_infos: List[TorchVisionDatasetInfo],
439
+ # Val info is not used for torchvision datamodule, but kept to keep things consistent
440
+ val_infos: Optional[List[TorchVisionDatasetInfo]] = None,
441
+ dataset_root: Optional[str] = None,
442
+ image_transforms: Optional[Tuple[Callable, Callable]] = None,
443
+ batch_size: int = 32,
444
+ num_workers: int = 4,
445
+ **kwargs: Any,
446
+ ):
447
+ super().__init__()
448
+ self.train_info = train_infos[0]
449
+ if val_infos is None:
450
+ val_infos = train_infos
451
+ self.val_info = val_infos[0]
452
+
453
+ self.train_class_ptr, self.train_root = self._parse_info(
454
+ self.train_info, dataset_root=dataset_root
455
+ )
456
+ self.val_class_ptr, self.val_root = self._parse_info(
457
+ self.val_info, dataset_root=dataset_root
458
+ )
459
+
460
+ if image_transforms is None:
461
+ image_transforms = default_torchvision_transforms()
462
+
463
+ self.train_transform, self.test_transform = image_transforms
464
+ self.batch_size = batch_size
465
+ self.num_workers = num_workers
466
+
467
+ def _parse_info(
468
+ self, info: TorchVisionDatasetInfo, dataset_root: Optional[str] = None
469
+ ):
470
+ assert hasattr(
471
+ torchvision.datasets, info.key
472
+ ), f"No dataset named {info.key} present in torchvision.datasets"
473
+ class_ptr = getattr(torchvision.datasets, info.key)
474
+ if dataset_root is None:
475
+ dataset_root = os.path.join(TRANSFORMERS_CACHE, "datasets", "torchvision")
476
+ dataset_root = os.path.join(dataset_root, class_ptr.__name__.lower())
477
+ os.makedirs(dataset_root, exist_ok=True)
478
+
479
+ return class_ptr, dataset_root
480
+
481
+ def setup(self, stage=None):
482
+ self.train_dataset = self.train_class_ptr(
483
+ self.train_root,
484
+ split=self.train_info.train_split,
485
+ transform=self.train_transform,
486
+ download=True,
487
+ )
488
+
489
+ if self.val_info.has_val:
490
+ self.val_dataset = self.val_class_ptr(
491
+ self.val_root,
492
+ split=self.val_info.val_split,
493
+ transform=self.test_transform,
494
+ download=True,
495
+ )
496
+
497
+ self.test_dataset = self.val_class_ptr(
498
+ self.val_root,
499
+ split=self.val_info.test_split,
500
+ transform=self.test_transform,
501
+ download=True,
502
+ )
503
+
504
+ def train_dataloader(self):
505
+ return self._build_dataloader(self.train_dataset)
506
+
507
+ def val_dataloader(self):
508
+ if self.val_info.has_val:
509
+ dataset = self.val_dataset
510
+ else:
511
+ dataset = self.test_dataset
512
+
513
+ return self._build_dataloader(dataset, shuffle=False)
514
+
515
+ def test_dataloader(self):
516
+ return self._build_dataloader(self.test_dataset, shuffle=False)
517
+
518
+ def _build_dataloader(self, dataset: torch.utils.data.Dataset, shuffle=True):
519
+ return torch.utils.data.DataLoader(
520
+ dataset,
521
+ shuffle=shuffle,
522
+ batch_size=self.batch_size,
523
+ num_workers=self.num_workers,
524
+ )
525
+
526
+ def on_before_batch_transfer(self, batch, *args):
527
+ images, targets = batch
528
+ batch = {"image": images, "labels": targets}
529
+ return batch
multimodal/examples/flava/data/imagenet_zeroshot_data.py ADDED
@@ -0,0 +1,1095 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # File taken from https://github.com/mlfoundations/open_clip/
8
+
9
+
10
+ imagenet_classnames = [
11
+ "tench",
12
+ "goldfish",
13
+ "great white shark",
14
+ "tiger shark",
15
+ "hammerhead shark",
16
+ "electric ray",
17
+ "stingray",
18
+ "rooster",
19
+ "hen",
20
+ "ostrich",
21
+ "brambling",
22
+ "goldfinch",
23
+ "house finch",
24
+ "junco",
25
+ "indigo bunting",
26
+ "American robin",
27
+ "bulbul",
28
+ "jay",
29
+ "magpie",
30
+ "chickadee",
31
+ "American dipper",
32
+ "kite (bird of prey)",
33
+ "bald eagle",
34
+ "vulture",
35
+ "great grey owl",
36
+ "fire salamander",
37
+ "smooth newt",
38
+ "newt",
39
+ "spotted salamander",
40
+ "axolotl",
41
+ "American bullfrog",
42
+ "tree frog",
43
+ "tailed frog",
44
+ "loggerhead sea turtle",
45
+ "leatherback sea turtle",
46
+ "mud turtle",
47
+ "terrapin",
48
+ "box turtle",
49
+ "banded gecko",
50
+ "green iguana",
51
+ "Carolina anole",
52
+ "desert grassland whiptail lizard",
53
+ "agama",
54
+ "frilled-necked lizard",
55
+ "alligator lizard",
56
+ "Gila monster",
57
+ "European green lizard",
58
+ "chameleon",
59
+ "Komodo dragon",
60
+ "Nile crocodile",
61
+ "American alligator",
62
+ "triceratops",
63
+ "worm snake",
64
+ "ring-necked snake",
65
+ "eastern hog-nosed snake",
66
+ "smooth green snake",
67
+ "kingsnake",
68
+ "garter snake",
69
+ "water snake",
70
+ "vine snake",
71
+ "night snake",
72
+ "boa constrictor",
73
+ "African rock python",
74
+ "Indian cobra",
75
+ "green mamba",
76
+ "sea snake",
77
+ "Saharan horned viper",
78
+ "eastern diamondback rattlesnake",
79
+ "sidewinder rattlesnake",
80
+ "trilobite",
81
+ "harvestman",
82
+ "scorpion",
83
+ "yellow garden spider",
84
+ "barn spider",
85
+ "European garden spider",
86
+ "southern black widow",
87
+ "tarantula",
88
+ "wolf spider",
89
+ "tick",
90
+ "centipede",
91
+ "black grouse",
92
+ "ptarmigan",
93
+ "ruffed grouse",
94
+ "prairie grouse",
95
+ "peafowl",
96
+ "quail",
97
+ "partridge",
98
+ "african grey parrot",
99
+ "macaw",
100
+ "sulphur-crested cockatoo",
101
+ "lorikeet",
102
+ "coucal",
103
+ "bee eater",
104
+ "hornbill",
105
+ "hummingbird",
106
+ "jacamar",
107
+ "toucan",
108
+ "duck",
109
+ "red-breasted merganser",
110
+ "goose",
111
+ "black swan",
112
+ "tusker",
113
+ "echidna",
114
+ "platypus",
115
+ "wallaby",
116
+ "koala",
117
+ "wombat",
118
+ "jellyfish",
119
+ "sea anemone",
120
+ "brain coral",
121
+ "flatworm",
122
+ "nematode",
123
+ "conch",
124
+ "snail",
125
+ "slug",
126
+ "sea slug",
127
+ "chiton",
128
+ "chambered nautilus",
129
+ "Dungeness crab",
130
+ "rock crab",
131
+ "fiddler crab",
132
+ "red king crab",
133
+ "American lobster",
134
+ "spiny lobster",
135
+ "crayfish",
136
+ "hermit crab",
137
+ "isopod",
138
+ "white stork",
139
+ "black stork",
140
+ "spoonbill",
141
+ "flamingo",
142
+ "little blue heron",
143
+ "great egret",
144
+ "bittern bird",
145
+ "crane bird",
146
+ "limpkin",
147
+ "common gallinule",
148
+ "American coot",
149
+ "bustard",
150
+ "ruddy turnstone",
151
+ "dunlin",
152
+ "common redshank",
153
+ "dowitcher",
154
+ "oystercatcher",
155
+ "pelican",
156
+ "king penguin",
157
+ "albatross",
158
+ "grey whale",
159
+ "killer whale",
160
+ "dugong",
161
+ "sea lion",
162
+ "Chihuahua",
163
+ "Japanese Chin",
164
+ "Maltese",
165
+ "Pekingese",
166
+ "Shih Tzu",
167
+ "King Charles Spaniel",
168
+ "Papillon",
169
+ "toy terrier",
170
+ "Rhodesian Ridgeback",
171
+ "Afghan Hound",
172
+ "Basset Hound",
173
+ "Beagle",
174
+ "Bloodhound",
175
+ "Bluetick Coonhound",
176
+ "Black and Tan Coonhound",
177
+ "Treeing Walker Coonhound",
178
+ "English foxhound",
179
+ "Redbone Coonhound",
180
+ "borzoi",
181
+ "Irish Wolfhound",
182
+ "Italian Greyhound",
183
+ "Whippet",
184
+ "Ibizan Hound",
185
+ "Norwegian Elkhound",
186
+ "Otterhound",
187
+ "Saluki",
188
+ "Scottish Deerhound",
189
+ "Weimaraner",
190
+ "Staffordshire Bull Terrier",
191
+ "American Staffordshire Terrier",
192
+ "Bedlington Terrier",
193
+ "Border Terrier",
194
+ "Kerry Blue Terrier",
195
+ "Irish Terrier",
196
+ "Norfolk Terrier",
197
+ "Norwich Terrier",
198
+ "Yorkshire Terrier",
199
+ "Wire Fox Terrier",
200
+ "Lakeland Terrier",
201
+ "Sealyham Terrier",
202
+ "Airedale Terrier",
203
+ "Cairn Terrier",
204
+ "Australian Terrier",
205
+ "Dandie Dinmont Terrier",
206
+ "Boston Terrier",
207
+ "Miniature Schnauzer",
208
+ "Giant Schnauzer",
209
+ "Standard Schnauzer",
210
+ "Scottish Terrier",
211
+ "Tibetan Terrier",
212
+ "Australian Silky Terrier",
213
+ "Soft-coated Wheaten Terrier",
214
+ "West Highland White Terrier",
215
+ "Lhasa Apso",
216
+ "Flat-Coated Retriever",
217
+ "Curly-coated Retriever",
218
+ "Golden Retriever",
219
+ "Labrador Retriever",
220
+ "Chesapeake Bay Retriever",
221
+ "German Shorthaired Pointer",
222
+ "Vizsla",
223
+ "English Setter",
224
+ "Irish Setter",
225
+ "Gordon Setter",
226
+ "Brittany dog",
227
+ "Clumber Spaniel",
228
+ "English Springer Spaniel",
229
+ "Welsh Springer Spaniel",
230
+ "Cocker Spaniel",
231
+ "Sussex Spaniel",
232
+ "Irish Water Spaniel",
233
+ "Kuvasz",
234
+ "Schipperke",
235
+ "Groenendael dog",
236
+ "Malinois",
237
+ "Briard",
238
+ "Australian Kelpie",
239
+ "Komondor",
240
+ "Old English Sheepdog",
241
+ "Shetland Sheepdog",
242
+ "collie",
243
+ "Border Collie",
244
+ "Bouvier des Flandres dog",
245
+ "Rottweiler",
246
+ "German Shepherd Dog",
247
+ "Dobermann",
248
+ "Miniature Pinscher",
249
+ "Greater Swiss Mountain Dog",
250
+ "Bernese Mountain Dog",
251
+ "Appenzeller Sennenhund",
252
+ "Entlebucher Sennenhund",
253
+ "Boxer",
254
+ "Bullmastiff",
255
+ "Tibetan Mastiff",
256
+ "French Bulldog",
257
+ "Great Dane",
258
+ "St. Bernard",
259
+ "husky",
260
+ "Alaskan Malamute",
261
+ "Siberian Husky",
262
+ "Dalmatian",
263
+ "Affenpinscher",
264
+ "Basenji",
265
+ "pug",
266
+ "Leonberger",
267
+ "Newfoundland dog",
268
+ "Great Pyrenees dog",
269
+ "Samoyed",
270
+ "Pomeranian",
271
+ "Chow Chow",
272
+ "Keeshond",
273
+ "brussels griffon",
274
+ "Pembroke Welsh Corgi",
275
+ "Cardigan Welsh Corgi",
276
+ "Toy Poodle",
277
+ "Miniature Poodle",
278
+ "Standard Poodle",
279
+ "Mexican hairless dog (xoloitzcuintli)",
280
+ "grey wolf",
281
+ "Alaskan tundra wolf",
282
+ "red wolf or maned wolf",
283
+ "coyote",
284
+ "dingo",
285
+ "dhole",
286
+ "African wild dog",
287
+ "hyena",
288
+ "red fox",
289
+ "kit fox",
290
+ "Arctic fox",
291
+ "grey fox",
292
+ "tabby cat",
293
+ "tiger cat",
294
+ "Persian cat",
295
+ "Siamese cat",
296
+ "Egyptian Mau",
297
+ "cougar",
298
+ "lynx",
299
+ "leopard",
300
+ "snow leopard",
301
+ "jaguar",
302
+ "lion",
303
+ "tiger",
304
+ "cheetah",
305
+ "brown bear",
306
+ "American black bear",
307
+ "polar bear",
308
+ "sloth bear",
309
+ "mongoose",
310
+ "meerkat",
311
+ "tiger beetle",
312
+ "ladybug",
313
+ "ground beetle",
314
+ "longhorn beetle",
315
+ "leaf beetle",
316
+ "dung beetle",
317
+ "rhinoceros beetle",
318
+ "weevil",
319
+ "fly",
320
+ "bee",
321
+ "ant",
322
+ "grasshopper",
323
+ "cricket insect",
324
+ "stick insect",
325
+ "cockroach",
326
+ "praying mantis",
327
+ "cicada",
328
+ "leafhopper",
329
+ "lacewing",
330
+ "dragonfly",
331
+ "damselfly",
332
+ "red admiral butterfly",
333
+ "ringlet butterfly",
334
+ "monarch butterfly",
335
+ "small white butterfly",
336
+ "sulphur butterfly",
337
+ "gossamer-winged butterfly",
338
+ "starfish",
339
+ "sea urchin",
340
+ "sea cucumber",
341
+ "cottontail rabbit",
342
+ "hare",
343
+ "Angora rabbit",
344
+ "hamster",
345
+ "porcupine",
346
+ "fox squirrel",
347
+ "marmot",
348
+ "beaver",
349
+ "guinea pig",
350
+ "common sorrel horse",
351
+ "zebra",
352
+ "pig",
353
+ "wild boar",
354
+ "warthog",
355
+ "hippopotamus",
356
+ "ox",
357
+ "water buffalo",
358
+ "bison",
359
+ "ram (adult male sheep)",
360
+ "bighorn sheep",
361
+ "Alpine ibex",
362
+ "hartebeest",
363
+ "impala (antelope)",
364
+ "gazelle",
365
+ "arabian camel",
366
+ "llama",
367
+ "weasel",
368
+ "mink",
369
+ "European polecat",
370
+ "black-footed ferret",
371
+ "otter",
372
+ "skunk",
373
+ "badger",
374
+ "armadillo",
375
+ "three-toed sloth",
376
+ "orangutan",
377
+ "gorilla",
378
+ "chimpanzee",
379
+ "gibbon",
380
+ "siamang",
381
+ "guenon",
382
+ "patas monkey",
383
+ "baboon",
384
+ "macaque",
385
+ "langur",
386
+ "black-and-white colobus",
387
+ "proboscis monkey",
388
+ "marmoset",
389
+ "white-headed capuchin",
390
+ "howler monkey",
391
+ "titi monkey",
392
+ "Geoffroy's spider monkey",
393
+ "common squirrel monkey",
394
+ "ring-tailed lemur",
395
+ "indri",
396
+ "Asian elephant",
397
+ "African bush elephant",
398
+ "red panda",
399
+ "giant panda",
400
+ "snoek fish",
401
+ "eel",
402
+ "silver salmon",
403
+ "rock beauty fish",
404
+ "clownfish",
405
+ "sturgeon",
406
+ "gar fish",
407
+ "lionfish",
408
+ "pufferfish",
409
+ "abacus",
410
+ "abaya",
411
+ "academic gown",
412
+ "accordion",
413
+ "acoustic guitar",
414
+ "aircraft carrier",
415
+ "airliner",
416
+ "airship",
417
+ "altar",
418
+ "ambulance",
419
+ "amphibious vehicle",
420
+ "analog clock",
421
+ "apiary",
422
+ "apron",
423
+ "trash can",
424
+ "assault rifle",
425
+ "backpack",
426
+ "bakery",
427
+ "balance beam",
428
+ "balloon",
429
+ "ballpoint pen",
430
+ "Band-Aid",
431
+ "banjo",
432
+ "baluster / handrail",
433
+ "barbell",
434
+ "barber chair",
435
+ "barbershop",
436
+ "barn",
437
+ "barometer",
438
+ "barrel",
439
+ "wheelbarrow",
440
+ "baseball",
441
+ "basketball",
442
+ "bassinet",
443
+ "bassoon",
444
+ "swimming cap",
445
+ "bath towel",
446
+ "bathtub",
447
+ "station wagon",
448
+ "lighthouse",
449
+ "beaker",
450
+ "military hat (bearskin or shako)",
451
+ "beer bottle",
452
+ "beer glass",
453
+ "bell tower",
454
+ "baby bib",
455
+ "tandem bicycle",
456
+ "bikini",
457
+ "ring binder",
458
+ "binoculars",
459
+ "birdhouse",
460
+ "boathouse",
461
+ "bobsleigh",
462
+ "bolo tie",
463
+ "poke bonnet",
464
+ "bookcase",
465
+ "bookstore",
466
+ "bottle cap",
467
+ "hunting bow",
468
+ "bow tie",
469
+ "brass memorial plaque",
470
+ "bra",
471
+ "breakwater",
472
+ "breastplate",
473
+ "broom",
474
+ "bucket",
475
+ "buckle",
476
+ "bulletproof vest",
477
+ "high-speed train",
478
+ "butcher shop",
479
+ "taxicab",
480
+ "cauldron",
481
+ "candle",
482
+ "cannon",
483
+ "canoe",
484
+ "can opener",
485
+ "cardigan",
486
+ "car mirror",
487
+ "carousel",
488
+ "tool kit",
489
+ "cardboard box / carton",
490
+ "car wheel",
491
+ "automated teller machine",
492
+ "cassette",
493
+ "cassette player",
494
+ "castle",
495
+ "catamaran",
496
+ "CD player",
497
+ "cello",
498
+ "mobile phone",
499
+ "chain",
500
+ "chain-link fence",
501
+ "chain mail",
502
+ "chainsaw",
503
+ "storage chest",
504
+ "chiffonier",
505
+ "bell or wind chime",
506
+ "china cabinet",
507
+ "Christmas stocking",
508
+ "church",
509
+ "movie theater",
510
+ "cleaver",
511
+ "cliff dwelling",
512
+ "cloak",
513
+ "clogs",
514
+ "cocktail shaker",
515
+ "coffee mug",
516
+ "coffeemaker",
517
+ "spiral or coil",
518
+ "combination lock",
519
+ "computer keyboard",
520
+ "candy store",
521
+ "container ship",
522
+ "convertible",
523
+ "corkscrew",
524
+ "cornet",
525
+ "cowboy boot",
526
+ "cowboy hat",
527
+ "cradle",
528
+ "construction crane",
529
+ "crash helmet",
530
+ "crate",
531
+ "infant bed",
532
+ "Crock Pot",
533
+ "croquet ball",
534
+ "crutch",
535
+ "cuirass",
536
+ "dam",
537
+ "desk",
538
+ "desktop computer",
539
+ "rotary dial telephone",
540
+ "diaper",
541
+ "digital clock",
542
+ "digital watch",
543
+ "dining table",
544
+ "dishcloth",
545
+ "dishwasher",
546
+ "disc brake",
547
+ "dock",
548
+ "dog sled",
549
+ "dome",
550
+ "doormat",
551
+ "drilling rig",
552
+ "drum",
553
+ "drumstick",
554
+ "dumbbell",
555
+ "Dutch oven",
556
+ "electric fan",
557
+ "electric guitar",
558
+ "electric locomotive",
559
+ "entertainment center",
560
+ "envelope",
561
+ "espresso machine",
562
+ "face powder",
563
+ "feather boa",
564
+ "filing cabinet",
565
+ "fireboat",
566
+ "fire truck",
567
+ "fire screen",
568
+ "flagpole",
569
+ "flute",
570
+ "folding chair",
571
+ "football helmet",
572
+ "forklift",
573
+ "fountain",
574
+ "fountain pen",
575
+ "four-poster bed",
576
+ "freight car",
577
+ "French horn",
578
+ "frying pan",
579
+ "fur coat",
580
+ "garbage truck",
581
+ "gas mask or respirator",
582
+ "gas pump",
583
+ "goblet",
584
+ "go-kart",
585
+ "golf ball",
586
+ "golf cart",
587
+ "gondola",
588
+ "gong",
589
+ "gown",
590
+ "grand piano",
591
+ "greenhouse",
592
+ "radiator grille",
593
+ "grocery store",
594
+ "guillotine",
595
+ "hair clip",
596
+ "hair spray",
597
+ "half-track",
598
+ "hammer",
599
+ "hamper",
600
+ "hair dryer",
601
+ "hand-held computer",
602
+ "handkerchief",
603
+ "hard disk drive",
604
+ "harmonica",
605
+ "harp",
606
+ "combine harvester",
607
+ "hatchet",
608
+ "holster",
609
+ "home theater",
610
+ "honeycomb",
611
+ "hook",
612
+ "hoop skirt",
613
+ "gymnastic horizontal bar",
614
+ "horse-drawn vehicle",
615
+ "hourglass",
616
+ "iPod",
617
+ "clothes iron",
618
+ "carved pumpkin",
619
+ "jeans",
620
+ "jeep",
621
+ "T-shirt",
622
+ "jigsaw puzzle",
623
+ "rickshaw",
624
+ "joystick",
625
+ "kimono",
626
+ "knee pad",
627
+ "knot",
628
+ "lab coat",
629
+ "ladle",
630
+ "lampshade",
631
+ "laptop computer",
632
+ "lawn mower",
633
+ "lens cap",
634
+ "letter opener",
635
+ "library",
636
+ "lifeboat",
637
+ "lighter",
638
+ "limousine",
639
+ "ocean liner",
640
+ "lipstick",
641
+ "slip-on shoe",
642
+ "lotion",
643
+ "music speaker",
644
+ "loupe magnifying glass",
645
+ "sawmill",
646
+ "magnetic compass",
647
+ "messenger bag",
648
+ "mailbox",
649
+ "tights",
650
+ "one-piece bathing suit",
651
+ "manhole cover",
652
+ "maraca",
653
+ "marimba",
654
+ "mask",
655
+ "matchstick",
656
+ "maypole",
657
+ "maze",
658
+ "measuring cup",
659
+ "medicine cabinet",
660
+ "megalith",
661
+ "microphone",
662
+ "microwave oven",
663
+ "military uniform",
664
+ "milk can",
665
+ "minibus",
666
+ "miniskirt",
667
+ "minivan",
668
+ "missile",
669
+ "mitten",
670
+ "mixing bowl",
671
+ "mobile home",
672
+ "ford model t",
673
+ "modem",
674
+ "monastery",
675
+ "monitor",
676
+ "moped",
677
+ "mortar and pestle",
678
+ "graduation cap",
679
+ "mosque",
680
+ "mosquito net",
681
+ "vespa",
682
+ "mountain bike",
683
+ "tent",
684
+ "computer mouse",
685
+ "mousetrap",
686
+ "moving van",
687
+ "muzzle",
688
+ "metal nail",
689
+ "neck brace",
690
+ "necklace",
691
+ "baby pacifier",
692
+ "notebook computer",
693
+ "obelisk",
694
+ "oboe",
695
+ "ocarina",
696
+ "odometer",
697
+ "oil filter",
698
+ "pipe organ",
699
+ "oscilloscope",
700
+ "overskirt",
701
+ "bullock cart",
702
+ "oxygen mask",
703
+ "product packet / packaging",
704
+ "paddle",
705
+ "paddle wheel",
706
+ "padlock",
707
+ "paintbrush",
708
+ "pajamas",
709
+ "palace",
710
+ "pan flute",
711
+ "paper towel",
712
+ "parachute",
713
+ "parallel bars",
714
+ "park bench",
715
+ "parking meter",
716
+ "railroad car",
717
+ "patio",
718
+ "payphone",
719
+ "pedestal",
720
+ "pencil case",
721
+ "pencil sharpener",
722
+ "perfume",
723
+ "Petri dish",
724
+ "photocopier",
725
+ "plectrum",
726
+ "Pickelhaube",
727
+ "picket fence",
728
+ "pickup truck",
729
+ "pier",
730
+ "piggy bank",
731
+ "pill bottle",
732
+ "pillow",
733
+ "ping-pong ball",
734
+ "pinwheel",
735
+ "pirate ship",
736
+ "drink pitcher",
737
+ "block plane",
738
+ "planetarium",
739
+ "plastic bag",
740
+ "plate rack",
741
+ "farm plow",
742
+ "plunger",
743
+ "Polaroid camera",
744
+ "pole",
745
+ "police van",
746
+ "poncho",
747
+ "pool table",
748
+ "soda bottle",
749
+ "plant pot",
750
+ "potter's wheel",
751
+ "power drill",
752
+ "prayer rug",
753
+ "printer",
754
+ "prison",
755
+ "missile",
756
+ "projector",
757
+ "hockey puck",
758
+ "punching bag",
759
+ "purse",
760
+ "quill",
761
+ "quilt",
762
+ "race car",
763
+ "racket",
764
+ "radiator",
765
+ "radio",
766
+ "radio telescope",
767
+ "rain barrel",
768
+ "recreational vehicle",
769
+ "fishing casting reel",
770
+ "reflex camera",
771
+ "refrigerator",
772
+ "remote control",
773
+ "restaurant",
774
+ "revolver",
775
+ "rifle",
776
+ "rocking chair",
777
+ "rotisserie",
778
+ "eraser",
779
+ "rugby ball",
780
+ "ruler measuring stick",
781
+ "sneaker",
782
+ "safe",
783
+ "safety pin",
784
+ "salt shaker",
785
+ "sandal",
786
+ "sarong",
787
+ "saxophone",
788
+ "scabbard",
789
+ "weighing scale",
790
+ "school bus",
791
+ "schooner",
792
+ "scoreboard",
793
+ "CRT monitor",
794
+ "screw",
795
+ "screwdriver",
796
+ "seat belt",
797
+ "sewing machine",
798
+ "shield",
799
+ "shoe store",
800
+ "shoji screen / room divider",
801
+ "shopping basket",
802
+ "shopping cart",
803
+ "shovel",
804
+ "shower cap",
805
+ "shower curtain",
806
+ "ski",
807
+ "balaclava ski mask",
808
+ "sleeping bag",
809
+ "slide rule",
810
+ "sliding door",
811
+ "slot machine",
812
+ "snorkel",
813
+ "snowmobile",
814
+ "snowplow",
815
+ "soap dispenser",
816
+ "soccer ball",
817
+ "sock",
818
+ "solar thermal collector",
819
+ "sombrero",
820
+ "soup bowl",
821
+ "keyboard space bar",
822
+ "space heater",
823
+ "space shuttle",
824
+ "spatula",
825
+ "motorboat",
826
+ "spider web",
827
+ "spindle",
828
+ "sports car",
829
+ "spotlight",
830
+ "stage",
831
+ "steam locomotive",
832
+ "through arch bridge",
833
+ "steel drum",
834
+ "stethoscope",
835
+ "scarf",
836
+ "stone wall",
837
+ "stopwatch",
838
+ "stove",
839
+ "strainer",
840
+ "tram",
841
+ "stretcher",
842
+ "couch",
843
+ "stupa",
844
+ "submarine",
845
+ "suit",
846
+ "sundial",
847
+ "sunglasses",
848
+ "sunglasses",
849
+ "sunscreen",
850
+ "suspension bridge",
851
+ "mop",
852
+ "sweatshirt",
853
+ "swim trunks / shorts",
854
+ "swing",
855
+ "electrical switch",
856
+ "syringe",
857
+ "table lamp",
858
+ "tank",
859
+ "tape player",
860
+ "teapot",
861
+ "teddy bear",
862
+ "television",
863
+ "tennis ball",
864
+ "thatched roof",
865
+ "front curtain",
866
+ "thimble",
867
+ "threshing machine",
868
+ "throne",
869
+ "tile roof",
870
+ "toaster",
871
+ "tobacco shop",
872
+ "toilet seat",
873
+ "torch",
874
+ "totem pole",
875
+ "tow truck",
876
+ "toy store",
877
+ "tractor",
878
+ "semi-trailer truck",
879
+ "tray",
880
+ "trench coat",
881
+ "tricycle",
882
+ "trimaran",
883
+ "tripod",
884
+ "triumphal arch",
885
+ "trolleybus",
886
+ "trombone",
887
+ "hot tub",
888
+ "turnstile",
889
+ "typewriter keyboard",
890
+ "umbrella",
891
+ "unicycle",
892
+ "upright piano",
893
+ "vacuum cleaner",
894
+ "vase",
895
+ "vaulted or arched ceiling",
896
+ "velvet fabric",
897
+ "vending machine",
898
+ "vestment",
899
+ "viaduct",
900
+ "violin",
901
+ "volleyball",
902
+ "waffle iron",
903
+ "wall clock",
904
+ "wallet",
905
+ "wardrobe",
906
+ "military aircraft",
907
+ "sink",
908
+ "washing machine",
909
+ "water bottle",
910
+ "water jug",
911
+ "water tower",
912
+ "whiskey jug",
913
+ "whistle",
914
+ "hair wig",
915
+ "window screen",
916
+ "window shade",
917
+ "Windsor tie",
918
+ "wine bottle",
919
+ "airplane wing",
920
+ "wok",
921
+ "wooden spoon",
922
+ "wool",
923
+ "split-rail fence",
924
+ "shipwreck",
925
+ "sailboat",
926
+ "yurt",
927
+ "website",
928
+ "comic book",
929
+ "crossword",
930
+ "traffic or street sign",
931
+ "traffic light",
932
+ "dust jacket",
933
+ "menu",
934
+ "plate",
935
+ "guacamole",
936
+ "consomme",
937
+ "hot pot",
938
+ "trifle",
939
+ "ice cream",
940
+ "popsicle",
941
+ "baguette",
942
+ "bagel",
943
+ "pretzel",
944
+ "cheeseburger",
945
+ "hot dog",
946
+ "mashed potatoes",
947
+ "cabbage",
948
+ "broccoli",
949
+ "cauliflower",
950
+ "zucchini",
951
+ "spaghetti squash",
952
+ "acorn squash",
953
+ "butternut squash",
954
+ "cucumber",
955
+ "artichoke",
956
+ "bell pepper",
957
+ "cardoon",
958
+ "mushroom",
959
+ "Granny Smith apple",
960
+ "strawberry",
961
+ "orange",
962
+ "lemon",
963
+ "fig",
964
+ "pineapple",
965
+ "banana",
966
+ "jackfruit",
967
+ "cherimoya (custard apple)",
968
+ "pomegranate",
969
+ "hay",
970
+ "carbonara",
971
+ "chocolate syrup",
972
+ "dough",
973
+ "meatloaf",
974
+ "pizza",
975
+ "pot pie",
976
+ "burrito",
977
+ "red wine",
978
+ "espresso",
979
+ "tea cup",
980
+ "eggnog",
981
+ "mountain",
982
+ "bubble",
983
+ "cliff",
984
+ "coral reef",
985
+ "geyser",
986
+ "lakeshore",
987
+ "promontory",
988
+ "sandbar",
989
+ "beach",
990
+ "valley",
991
+ "volcano",
992
+ "baseball player",
993
+ "bridegroom",
994
+ "scuba diver",
995
+ "rapeseed",
996
+ "daisy",
997
+ "yellow lady's slipper",
998
+ "corn",
999
+ "acorn",
1000
+ "rose hip",
1001
+ "horse chestnut seed",
1002
+ "coral fungus",
1003
+ "agaric",
1004
+ "gyromitra",
1005
+ "stinkhorn mushroom",
1006
+ "earth star fungus",
1007
+ "hen of the woods mushroom",
1008
+ "bolete",
1009
+ "corn cob",
1010
+ "toilet paper",
1011
+ ]
1012
+
1013
+
1014
+ openai_imagenet_template = [
1015
+ lambda c: f"a bad photo of a {c}.",
1016
+ lambda c: f"a photo of many {c}.",
1017
+ lambda c: f"a sculpture of a {c}.",
1018
+ lambda c: f"a photo of the hard to see {c}.",
1019
+ lambda c: f"a low resolution photo of the {c}.",
1020
+ lambda c: f"a rendering of a {c}.",
1021
+ lambda c: f"graffiti of a {c}.",
1022
+ lambda c: f"a bad photo of the {c}.",
1023
+ lambda c: f"a cropped photo of the {c}.",
1024
+ lambda c: f"a tattoo of a {c}.",
1025
+ lambda c: f"the embroidered {c}.",
1026
+ lambda c: f"a photo of a hard to see {c}.",
1027
+ lambda c: f"a bright photo of a {c}.",
1028
+ lambda c: f"a photo of a clean {c}.",
1029
+ lambda c: f"a photo of a dirty {c}.",
1030
+ lambda c: f"a dark photo of the {c}.",
1031
+ lambda c: f"a drawing of a {c}.",
1032
+ lambda c: f"a photo of my {c}.",
1033
+ lambda c: f"the plastic {c}.",
1034
+ lambda c: f"a photo of the cool {c}.",
1035
+ lambda c: f"a close-up photo of a {c}.",
1036
+ lambda c: f"a black and white photo of the {c}.",
1037
+ lambda c: f"a painting of the {c}.",
1038
+ lambda c: f"a painting of a {c}.",
1039
+ lambda c: f"a pixelated photo of the {c}.",
1040
+ lambda c: f"a sculpture of the {c}.",
1041
+ lambda c: f"a bright photo of the {c}.",
1042
+ lambda c: f"a cropped photo of a {c}.",
1043
+ lambda c: f"a plastic {c}.",
1044
+ lambda c: f"a photo of the dirty {c}.",
1045
+ lambda c: f"a jpeg corrupted photo of a {c}.",
1046
+ lambda c: f"a blurry photo of the {c}.",
1047
+ lambda c: f"a photo of the {c}.",
1048
+ lambda c: f"a good photo of the {c}.",
1049
+ lambda c: f"a rendering of the {c}.",
1050
+ lambda c: f"a {c} in a video game.",
1051
+ lambda c: f"a photo of one {c}.",
1052
+ lambda c: f"a doodle of a {c}.",
1053
+ lambda c: f"a close-up photo of the {c}.",
1054
+ lambda c: f"a photo of a {c}.",
1055
+ lambda c: f"the origami {c}.",
1056
+ lambda c: f"the {c} in a video game.",
1057
+ lambda c: f"a sketch of a {c}.",
1058
+ lambda c: f"a doodle of the {c}.",
1059
+ lambda c: f"a origami {c}.",
1060
+ lambda c: f"a low resolution photo of a {c}.",
1061
+ lambda c: f"the toy {c}.",
1062
+ lambda c: f"a rendition of the {c}.",
1063
+ lambda c: f"a photo of the clean {c}.",
1064
+ lambda c: f"a photo of a large {c}.",
1065
+ lambda c: f"a rendition of a {c}.",
1066
+ lambda c: f"a photo of a nice {c}.",
1067
+ lambda c: f"a photo of a weird {c}.",
1068
+ lambda c: f"a blurry photo of a {c}.",
1069
+ lambda c: f"a cartoon {c}.",
1070
+ lambda c: f"art of a {c}.",
1071
+ lambda c: f"a sketch of the {c}.",
1072
+ lambda c: f"a embroidered {c}.",
1073
+ lambda c: f"a pixelated photo of a {c}.",
1074
+ lambda c: f"itap of the {c}.",
1075
+ lambda c: f"a jpeg corrupted photo of the {c}.",
1076
+ lambda c: f"a good photo of a {c}.",
1077
+ lambda c: f"a plushie {c}.",
1078
+ lambda c: f"a photo of the nice {c}.",
1079
+ lambda c: f"a photo of the small {c}.",
1080
+ lambda c: f"a photo of the weird {c}.",
1081
+ lambda c: f"the cartoon {c}.",
1082
+ lambda c: f"art of the {c}.",
1083
+ lambda c: f"a drawing of the {c}.",
1084
+ lambda c: f"a photo of the large {c}.",
1085
+ lambda c: f"a black and white photo of a {c}.",
1086
+ lambda c: f"the plushie {c}.",
1087
+ lambda c: f"a dark photo of a {c}.",
1088
+ lambda c: f"itap of a {c}.",
1089
+ lambda c: f"graffiti of the {c}.",
1090
+ lambda c: f"a toy {c}.",
1091
+ lambda c: f"itap of my {c}.",
1092
+ lambda c: f"a photo of a cool {c}.",
1093
+ lambda c: f"a photo of a small {c}.",
1094
+ lambda c: f"a tattoo of the {c}.",
1095
+ ]
multimodal/examples/flava/data/transforms.py ADDED
@@ -0,0 +1,131 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import random
+ from functools import partial
+ from typing import Any, Callable, Optional
+
+ import torch
+ from torchmultimodal.transforms.flava_transform import FLAVAImageTransform
+ from torchvision import transforms
+ from transformers import BertTokenizer
+
+
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+ IMAGE_DEFAULT_SIZE = (224, 224)
+ VL_MAX_LENGTH_DEFAULT = 77
+ TEXT_MAX_LENGTH_DEFAULT = 512
+ TEXT_DEFAULT_TOKENIZER = "bert-base-uncased"
+ TEXT_WHOLE_WORD_MASK_TOKENIZER = "bert-large-uncased-whole-word-masking"
+
+
+ def encode_text(text, tokenizer, *args, **kwargs):
+     return tokenizer(text, *args, **kwargs)
+
+
+ def encode_text_batch(
+     batch, tokenizer, text_columns, return_batch=False, *args, **kwargs
+ ):
+     texts = [batch[column] for column in text_columns]
+     tokens = tokenizer(*texts, *args, **kwargs)
+     if return_batch:
+         batch.update(tokens)
+         return batch
+     return tokens
+
+
+ def transform_image_dict(transform, image_dict, *args, **kwargs):
+     return {"image": transform(image_dict["image"], *args, **kwargs)}
+
+
+ def default_torchvision_transforms(
+     size=IMAGE_DEFAULT_SIZE,
+     mean=IMAGENET_DEFAULT_MEAN,
+     std=IMAGENET_DEFAULT_STD,
+     use_dict=False,
+ ):
+     transform = transforms.Compose(
+         [
+             transforms.Resize(size),
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=mean,
+                 std=std,
+             ),
+         ]
+     )
+
+     if use_dict:
+         transform = partial(transform_image_dict, transform=transform)
+
+     return transform, transform
+
+
+ def default_image_pretraining_transforms():
+     return FLAVAImageTransform(), FLAVAImageTransform(is_train=False)
+
+
+ def default_text_transform(
+     text_tokenizer: Optional[Callable] = None,
+     max_text_length: int = TEXT_MAX_LENGTH_DEFAULT,
+     **kwargs: Any,
+ ):
+     if text_tokenizer is None:
+         text_tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
+
+     text_transform = partial(
+         encode_text,
+         tokenizer=text_tokenizer,
+         padding="max_length",
+         max_length=max_text_length,
+         truncation=True,
+         return_tensors="pt",
+         return_special_tokens_mask=True,
+     )
+
+     return text_transform
+
+
+ def default_vl_text_transform(
+     text_tokenizer: Optional[Callable] = None,
+     max_text_length: int = VL_MAX_LENGTH_DEFAULT,
+     **kwargs: Any,
+ ):
+     if text_tokenizer is None:
+         text_tokenizer = BertTokenizer.from_pretrained(TEXT_WHOLE_WORD_MASK_TOKENIZER)
+     return default_text_transform(text_tokenizer, max_text_length=max_text_length)
+
+
+ def pad_batch(batch, batch_size):
+     for item in batch.keys():
+         if isinstance(batch[item], torch.Tensor):
+             diff = batch_size - batch[item].size(0)
+             pad = batch[item][-diff:].detach().clone()
+             batch[item] = torch.cat([batch[item], pad], dim=0)
+     return batch
+
+
+ class VLTransform:
+     def __init__(self, image_transform, text_transform):
+         self.image_transform = image_transform
+         self.text_transform = text_transform
+
+     def __call__(self, info, dataset, itm_probability):
+         output = {}
+         text = info["text"]
+         image = info["image"]
+         if itm_probability > 0:
+             output["itm_labels"] = torch.ones((1), dtype=torch.long)
+
+         if random.random() < itm_probability:
+             while text == info["text"]:
+                 text = dataset.select([random.randint(0, len(dataset) - 1)])[0]["text"]
+             output["itm_labels"] = torch.zeros((1), dtype=torch.long)
+
+         output.update(self.image_transform(image))
+         output.update(self.text_transform(text))
+         return output
multimodal/examples/flava/data/utils.py ADDED
@@ -0,0 +1,80 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
+ from typing import List
+
+ import requests
+ from datasets import concatenate_datasets, load_dataset
+ from datasets.utils.file_utils import get_datasets_user_agent
+ from flava.definitions import HFDatasetInfo
+ from PIL import Image, UnidentifiedImageError
+
+
+ DATASETS_USER_AGENT = get_datasets_user_agent()
+
+
+ def build_datasets_from_info(dataset_infos: List[HFDatasetInfo], split: str = "train"):
+     dataset_list = []
+     for dataset_info in dataset_infos:
+         current_dataset = load_dataset(
+             dataset_info.key,
+             dataset_info.subset,
+             split=dataset_info.split_key_mapping[split],
+             use_auth_token=True,
+             **dataset_info.extra_kwargs,
+         )
+         if dataset_info.remove_columns is not None:
+             current_dataset = current_dataset.remove_columns(
+                 dataset_info.remove_columns
+             )
+         if dataset_info.rename_columns is not None:
+             for rename in dataset_info.rename_columns:
+                 current_dataset = current_dataset.rename_column(rename[0], rename[1])
+
+         dataset_list.append(current_dataset)
+
+     return concatenate_datasets(dataset_list)
+
+
+ def fetch_single_image(image_url, timeout, retries=0, sleep_timer=0):
+     for _ in range(retries + 1):
+         try:
+             image = Image.open(
+                 requests.get(
+                     image_url,
+                     stream=True,
+                     headers={"user-agent": DATASETS_USER_AGENT},
+                     timeout=timeout,
+                 ).raw
+             )
+             break
+         except (requests.exceptions.ConnectionError, UnidentifiedImageError):
+             image = None
+             time.sleep(sleep_timer)
+
+     return image
+
+
+ def fetch_images(batch, num_threads, timeout=None, retries=0, sleep_timer=0):
+     if "image" in batch:
+         # This dataset already has "image" defined.
+         return batch
+     with ThreadPoolExecutor(max_workers=num_threads) as executor:
+         batch["image"] = list(
+             executor.map(
+                 partial(
+                     fetch_single_image,
+                     timeout=timeout,
+                     retries=retries,
+                     sleep_timer=sleep_timer,
+                 ),
+                 batch["image_url"],
+             )
+         )
+     return batch
multimodal/examples/flava/native/README.md ADDED
@@ -0,0 +1,43 @@
+ # Usage Instructions
+
+ This is a lightweight native PyTorch implementation to run scaling studies on the FLAVA model. The original code is located at: [`examples/flava/train.py`](https://github.com/facebookresearch/multimodal/blob/main/examples/flava/train.py)
+
+ ## Prerequisites
+
+ - Install the torchmultimodal library [from source](https://github.com/facebookresearch/multimodal/blob/main/README.md#building-from-source)
+ - `cd multimodal/examples`
+ - `pip install -r flava/requirements.txt`
+
+ ## Training
+
+ ### Configuration
+
+ Configuration presets for various model sizes can be found at: `examples/flava/native/configs`
+
+ Some config settings that are relevant for scaling: (local) `batch_size`, `activation_checkpointing`, `strategy`.
+
+ Configs can be overridden through the command line, for example: `python -m flava.native.train config=flava/native/configs/pretrain_debug.yaml training.batch_size=8 training.enable_amp=True training.activation_checkpointing=True training.strategy=fsdp`
+
+ ### Running
+
+ Using [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html):
+
+ **Single node**
+
+ `NUM_GPUS=8; torchrun --nproc_per_node=$NUM_GPUS -m flava.native.train config=flava/native/configs/pretrain_debug.yaml`
+
+ **Multiple nodes (using slurm)**
+
+ Create a `run.slurm` file:
+
+ ```bash
+ RDZV_ENDPOINT=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+
+ srun torchrun --nnodes=$SLURM_NNODES --nproc_per_node=$SLURM_GPUS_PER_TASK --rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$RDZV_ENDPOINT --max_restarts 0 -m flava.native.train config=flava/native/configs/pretrain_debug.yaml
+ $@
+ ```
+
+ Run in a terminal:
+
+ `sbatch --partition=[PARTITION] --nodes=[NUM_NODES] --gpus-per-task=[NUM_GPUS_PER_NODE] run.slurm`
multimodal/examples/flava/native/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
multimodal/examples/flava/native/configs/1.8b.yaml ADDED
@@ -0,0 +1,79 @@
1
+ training:
2
+ strategy: fsdp # can be changed to ddp or fsdp
3
+ seed: 1337
4
+
5
+ batch_size: 8
6
+ num_workers: 4
7
+ prefetch_factor: 3
8
+
9
+ optimizer:
10
+ learning_rate: 1e-3
11
+ adam_eps: 1e-8
12
+ adam_weight_decay: 0.1
13
+ adam_betas: [0.9, 0.999]
14
+
15
+ warmup_steps: 10000
16
+ max_steps: 100000
17
+
18
+ validation_steps: 5000
19
+ log_interval: 10
20
+
21
+ enable_tf32: True
22
+ enable_amp: True
23
+ half_precision_format: "bfloat16" # or float16
24
+ enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
25
+
26
+ activation_checkpointing: True
27
+
28
+ datasets:
29
+ _target_: flava.definitions.TrainingDatasetsInfo
30
+ selected:
31
+ - image
32
+ - vl
33
+ - text
34
+ image:
35
+ _target_: flava.definitions.TrainingSingleDatasetInfo
36
+ train:
37
+ - _target_: flava.definitions.HFDatasetInfo
38
+ key: imagenet-1k
39
+ subset: default
40
+ text:
41
+ _target_: flava.definitions.TrainingSingleDatasetInfo
42
+ train:
43
+ - _target_: flava.definitions.HFDatasetInfo
44
+ key: wikitext
45
+ subset: wikitext-103-raw-v1
46
+ datamodule_extra_kwargs:
47
+ text_columns: ["text"]
48
+ vl:
49
+ _target_: flava.definitions.TrainingSingleDatasetInfo
50
+ train:
51
+ - _target_: flava.definitions.HFDatasetInfo
52
+ key: red_caps
53
+ subset: backpacking
54
+ rename_columns:
55
+ - ["caption", "text"]
56
+ val:
57
+ - _target_: flava.definitions.HFDatasetInfo
58
+ key: red_caps
59
+ subset: backpacking
60
+ rename_columns:
61
+ - ["caption", "text"]
62
+ split_key_mapping:
63
+ validation: train
64
+
65
+ model:
66
+ image_num_hidden_layers: 32
67
+ image_hidden_size: 1280
68
+ image_intermediate_size: 5120
69
+ image_num_attention_heads: 16
70
+
71
+ text_num_hidden_layers: 32
72
+ text_hidden_size: 1280
73
+ text_intermediate_size: 5120
74
+ text_num_attention_heads: 16
75
+
76
+ multimodal_num_hidden_layers: 16
77
+ multimodal_hidden_size: 1280
78
+ multimodal_intermediate_size: 5120
79
+ multimodal_num_attention_heads: 16
multimodal/examples/flava/native/configs/10b.yaml ADDED
@@ -0,0 +1,80 @@
1
+ training:
2
+ strategy: fsdp # can be changed to ddp or fsdp
3
+ seed: 1337
4
+
5
+ batch_size: 8
6
+ num_workers: 4
7
+ prefetch_factor: 3
8
+
9
+ optimizer:
10
+ learning_rate: 1e-3
11
+ adam_eps: 1e-8
12
+ adam_weight_decay: 0.1
13
+ adam_betas: [0.9, 0.999]
14
+
15
+ warmup_steps: 10000
16
+ max_steps: 100000
17
+
18
+ validation_steps: 5000
19
+ log_interval: 10
20
+
21
+ enable_tf32: True
22
+ enable_amp: True
23
+ half_precision_format: "bfloat16" # or float16
24
+ enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
25
+
26
+ activation_checkpointing: True
27
+
28
+ datasets:
29
+ _target_: flava.definitions.TrainingDatasetsInfo
30
+ selected:
31
+ - image
32
+ - vl
33
+ - text
34
+ image:
35
+ _target_: flava.definitions.TrainingSingleDatasetInfo
36
+ train:
37
+ - _target_: flava.definitions.HFDatasetInfo
38
+ key: imagenet-1k
39
+ subset: default
40
+ text:
41
+ _target_: flava.definitions.TrainingSingleDatasetInfo
42
+ train:
43
+ - _target_: flava.definitions.HFDatasetInfo
44
+ key: wikitext
45
+ subset: wikitext-103-raw-v1
46
+ datamodule_extra_kwargs:
47
+ text_columns: ["text"]
48
+ vl:
49
+ _target_: flava.definitions.TrainingSingleDatasetInfo
50
+ train:
51
+ - _target_: flava.definitions.HFDatasetInfo
52
+ key: red_caps
53
+ subset: backpacking
54
+ rename_columns:
55
+ - ["caption", "text"]
56
+ val:
57
+ - _target_: flava.definitions.HFDatasetInfo
58
+ key: red_caps
59
+ subset: backpacking
60
+ rename_columns:
61
+ - ["caption", "text"]
62
+ split_key_mapping:
63
+ validation: train
64
+
65
+
66
+ model:
67
+ image_num_hidden_layers: 64
68
+ image_hidden_size: 2048
69
+ image_intermediate_size: 10240
70
+ image_num_attention_heads: 16
71
+
72
+ text_num_hidden_layers: 64
73
+ text_hidden_size: 2048
74
+ text_intermediate_size: 10240
75
+ text_num_attention_heads: 16
76
+
77
+ multimodal_num_hidden_layers: 40
78
+ multimodal_hidden_size: 2048
79
+ multimodal_intermediate_size: 10240
80
+ multimodal_num_attention_heads: 16
multimodal/examples/flava/native/configs/2.7b.yaml ADDED
@@ -0,0 +1,79 @@
1
+ training:
2
+ strategy: fsdp # can be changed to ddp or fsdp
3
+ seed: 1337
4
+
5
+ batch_size: 8
6
+ num_workers: 4
7
+ prefetch_factor: 3
8
+
9
+ optimizer:
10
+ learning_rate: 1e-3
11
+ adam_eps: 1e-8
12
+ adam_weight_decay: 0.1
13
+ adam_betas: [0.9, 0.999]
14
+
15
+ warmup_steps: 10000
16
+ max_steps: 100000
17
+
18
+ validation_steps: 5000
19
+ log_interval: 10
20
+
21
+ enable_tf32: True
22
+ enable_amp: True
23
+ half_precision_format: "bfloat16" # or float16
24
+ enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
25
+
26
+ activation_checkpointing: True
27
+
28
+ datasets:
29
+ _target_: flava.definitions.TrainingDatasetsInfo
30
+ selected:
31
+ - image
32
+ - vl
33
+ - text
34
+ image:
35
+ _target_: flava.definitions.TrainingSingleDatasetInfo
36
+ train:
37
+ - _target_: flava.definitions.HFDatasetInfo
38
+ key: imagenet-1k
39
+ subset: default
40
+ text:
41
+ _target_: flava.definitions.TrainingSingleDatasetInfo
42
+ train:
43
+ - _target_: flava.definitions.HFDatasetInfo
44
+ key: wikitext
45
+ subset: wikitext-103-raw-v1
46
+ datamodule_extra_kwargs:
47
+ text_columns: ["text"]
48
+ vl:
49
+ _target_: flava.definitions.TrainingSingleDatasetInfo
50
+ train:
51
+ - _target_: flava.definitions.HFDatasetInfo
52
+ key: red_caps
53
+ subset: backpacking
54
+ rename_columns:
55
+ - ["caption", "text"]
56
+ val:
57
+ - _target_: flava.definitions.HFDatasetInfo
58
+ key: red_caps
59
+ subset: backpacking
60
+ rename_columns:
61
+ - ["caption", "text"]
62
+ split_key_mapping:
63
+ validation: train
64
+
65
+ model:
66
+ image_num_hidden_layers: 40
67
+ image_hidden_size: 1408
68
+ image_intermediate_size: 6144
69
+ image_num_attention_heads: 16
70
+
71
+ text_num_hidden_layers: 40
72
+ text_hidden_size: 1408
73
+ text_intermediate_size: 6144
74
+ text_num_attention_heads: 16
75
+
76
+ multimodal_num_hidden_layers: 20
77
+ multimodal_hidden_size: 1408
78
+ multimodal_intermediate_size: 6144
79
+ multimodal_num_attention_heads: 16
multimodal/examples/flava/native/configs/4.8b.yaml ADDED
@@ -0,0 +1,79 @@
1
+ training:
2
+ strategy: fsdp # can be changed to ddp or fsdp
3
+ seed: 1337
4
+
5
+ batch_size: 12
6
+ num_workers: 4
7
+ prefetch_factor: 3
8
+
9
+ optimizer:
10
+ learning_rate: 1e-3
11
+ adam_eps: 1e-8
12
+ adam_weight_decay: 0.1
13
+ adam_betas: [0.9, 0.999]
14
+
15
+ warmup_steps: 10000
16
+ max_steps: 100000
17
+
18
+ validation_steps: 5000
19
+ log_interval: 10
20
+
21
+ enable_tf32: True
22
+ enable_amp: True
23
+ half_precision_format: "bfloat16" # or float16
24
+ enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
25
+
26
+ activation_checkpointing: True
27
+
28
+ datasets:
29
+ _target_: flava.definitions.TrainingDatasetsInfo
30
+ selected:
31
+ - image
32
+ - vl
33
+ - text
34
+ image:
35
+ _target_: flava.definitions.TrainingSingleDatasetInfo
36
+ train:
37
+ - _target_: flava.definitions.HFDatasetInfo
38
+ key: imagenet-1k
39
+ subset: default
40
+ text:
41
+ _target_: flava.definitions.TrainingSingleDatasetInfo
42
+ train:
43
+ - _target_: flava.definitions.HFDatasetInfo
44
+ key: wikitext
45
+ subset: wikitext-103-raw-v1
46
+ datamodule_extra_kwargs:
47
+ text_columns: ["text"]
48
+ vl:
49
+ _target_: flava.definitions.TrainingSingleDatasetInfo
50
+ train:
51
+ - _target_: flava.definitions.HFDatasetInfo
52
+ key: red_caps
53
+ subset: backpacking
54
+ rename_columns:
55
+ - ["caption", "text"]
56
+ val:
57
+ - _target_: flava.definitions.HFDatasetInfo
58
+ key: red_caps
59
+ subset: backpacking
60
+ rename_columns:
61
+ - ["caption", "text"]
62
+ split_key_mapping:
63
+ validation: train
64
+
65
+ model:
66
+ image_num_hidden_layers: 48
67
+ image_hidden_size: 1664
68
+ image_intermediate_size: 8192
69
+ image_num_attention_heads: 16
70
+
71
+ text_num_hidden_layers: 48
72
+ text_hidden_size: 1664
73
+ text_intermediate_size: 8192
74
+ text_num_attention_heads: 16
75
+
76
+ multimodal_num_hidden_layers: 24
77
+ multimodal_hidden_size: 1664
78
+ multimodal_intermediate_size: 8192
79
+ multimodal_num_attention_heads: 16
multimodal/examples/flava/native/configs/900m.yaml ADDED
@@ -0,0 +1,79 @@
1
+ training:
2
+ strategy: ddp # can be changed to ddp or fsdp
3
+ seed: 1337
4
+
5
+ batch_size: 8
6
+ num_workers: 4
7
+ prefetch_factor: 3
8
+
9
+ optimizer:
10
+ learning_rate: 1e-3
11
+ adam_eps: 1e-8
12
+ adam_weight_decay: 0.1
13
+ adam_betas: [0.9, 0.999]
14
+
15
+ warmup_steps: 10000
16
+ max_steps: 100000
17
+
18
+ validation_steps: 5000
19
+ log_interval: 10
20
+
21
+ enable_tf32: True
22
+ enable_amp: True
23
+ half_precision_format: "bfloat16" # or float16
24
+ enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
25
+
26
+ activation_checkpointing: True
27
+
28
+ datasets:
29
+ _target_: flava.definitions.TrainingDatasetsInfo
30
+ selected:
31
+ - image
32
+ - vl
33
+ - text
34
+ image:
35
+ _target_: flava.definitions.TrainingSingleDatasetInfo
36
+ train:
37
+ - _target_: flava.definitions.HFDatasetInfo
38
+ key: imagenet-1k
39
+ subset: default
40
+ text:
41
+ _target_: flava.definitions.TrainingSingleDatasetInfo
42
+ train:
43
+ - _target_: flava.definitions.HFDatasetInfo
44
+ key: wikitext
45
+ subset: wikitext-103-raw-v1
46
+ datamodule_extra_kwargs:
47
+ text_columns: ["text"]
48
+ vl:
49
+ _target_: flava.definitions.TrainingSingleDatasetInfo
50
+ train:
51
+ - _target_: flava.definitions.HFDatasetInfo
52
+ key: red_caps
53
+ subset: backpacking
54
+ rename_columns:
55
+ - ["caption", "text"]
56
+ val:
57
+ - _target_: flava.definitions.HFDatasetInfo
58
+ key: red_caps
59
+ subset: backpacking
60
+ rename_columns:
61
+ - ["caption", "text"]
62
+ split_key_mapping:
63
+ validation: train
64
+
65
+ model:
66
+ image_num_hidden_layers: 24
67
+ image_hidden_size: 1024
68
+ image_intermediate_size: 4096
69
+ image_num_attention_heads: 16
70
+
71
+ text_num_hidden_layers: 24
72
+ text_hidden_size: 1024
73
+ text_intermediate_size: 4096
74
+ text_num_attention_heads: 16
75
+
76
+ multimodal_num_hidden_layers: 12
77
+ multimodal_hidden_size: 1024
78
+ multimodal_intermediate_size: 4096
79
+ multimodal_num_attention_heads: 16
multimodal/examples/flava/native/configs/pretrain_debug.yaml ADDED
@@ -0,0 +1,63 @@
1
+ training:
2
+ strategy: ddp # can be changed to ddp or fsdp
3
+ seed: 1337
4
+
5
+ batch_size: 8
6
+ num_workers: 4
7
+ prefetch_factor: 3
8
+
9
+ optimizer:
10
+ learning_rate: 1e-3
11
+ adam_eps: 1e-8
12
+ adam_weight_decay: 0.1
13
+ adam_betas: [0.9, 0.999]
14
+
15
+ warmup_steps: 10000
16
+ max_steps: 100000
17
+
18
+ validation_steps: 5000
19
+ log_interval: 10
20
+
21
+ enable_tf32: True
22
+ enable_amp: True
23
+ half_precision_format: "bfloat16" # or float16
24
+ enable_half_reduce_in_fsdp: True # handles the reduction across devices in half precision
25
+
26
+ activation_checkpointing: False
27
+
28
+ datasets:
29
+ _target_: flava.definitions.TrainingDatasetsInfo
30
+ selected:
31
+ - image
32
+ - vl
33
+ - text
34
+ image:
35
+ _target_: flava.definitions.TrainingSingleDatasetInfo
36
+ train:
37
+ - _target_: flava.definitions.HFDatasetInfo
38
+ key: imagenet-1k
39
+ subset: default
40
+ text:
41
+ _target_: flava.definitions.TrainingSingleDatasetInfo
42
+ train:
43
+ - _target_: flava.definitions.HFDatasetInfo
44
+ key: wikitext
45
+ subset: wikitext-103-raw-v1
46
+ datamodule_extra_kwargs:
47
+ text_columns: ["text"]
48
+ vl:
49
+ _target_: flava.definitions.TrainingSingleDatasetInfo
50
+ train:
51
+ - _target_: flava.definitions.HFDatasetInfo
52
+ key: red_caps
53
+ subset: backpacking
54
+ rename_columns:
55
+ - ["caption", "text"]
56
+ val:
57
+ - _target_: flava.definitions.HFDatasetInfo
58
+ key: red_caps
59
+ subset: backpacking
60
+ rename_columns:
61
+ - ["caption", "text"]
62
+ split_key_mapping:
63
+ validation: train
multimodal/examples/flava/native/data.py ADDED
@@ -0,0 +1,560 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from functools import partial
9
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torchvision
14
+
15
+ from flava.data.transforms import (
16
+ default_image_pretraining_transforms,
17
+ default_text_transform,
18
+ default_torchvision_transforms,
19
+ encode_text_batch,
20
+ pad_batch,
21
+ TEXT_DEFAULT_TOKENIZER,
22
+ TEXT_WHOLE_WORD_MASK_TOKENIZER,
23
+ VL_MAX_LENGTH_DEFAULT,
24
+ VLTransform,
25
+ )
26
+ from flava.data.utils import build_datasets_from_info, fetch_images
27
+ from flava.definitions import HFDatasetInfo, TorchVisionDatasetInfo
28
+ from pytorch_lightning import LightningDataModule
29
+ from torch.utils.data.distributed import DistributedSampler
30
+ from transformers import (
31
+ BertTokenizer,
32
+ DataCollatorForLanguageModeling,
33
+ DataCollatorForWholeWordMask,
34
+ DefaultDataCollator,
35
+ TRANSFORMERS_CACHE,
36
+ )
37
+ from transformers.data.data_collator import torch_default_data_collator
38
+
39
+
40
+ def transform_image(transform, sample):
41
+ sample.update(transform(sample["image"]))
42
+ return sample
43
+
44
+
45
+ def get_sampler(dataset, shuffle=True):
46
+ if dist.is_initialized():
47
+ return DistributedSampler(dataset, shuffle=shuffle)
48
+ if shuffle:
49
+ return torch.utils.data.RandomSampler(dataset)
50
+ return torch.utils.data.SequentialSampler(dataset)
51
+
52
+
53
+ class DataCollatorForWholeWordMaskRetainingBatch(DataCollatorForWholeWordMask):
54
+ def torch_call(
55
+ self, examples: List[Union[List[int], Any, Dict[str, Any]]]
56
+ ) -> Dict[str, Any]:
57
+ masked_batch = super().torch_call(examples)
58
+ examples = torch_default_data_collator(examples)
59
+ examples["input_ids"] = masked_batch["input_ids"]
60
+ examples["labels"] = masked_batch["labels"]
61
+ return examples
62
+
63
+
64
+ class ImageDataModule(LightningDataModule):
65
+ def __init__(
66
+ self,
67
+ train_infos: List[HFDatasetInfo],
68
+ val_infos: Optional[List[HFDatasetInfo]] = None,
69
+ transforms: Optional[Tuple[Callable, Callable]] = None,
70
+ batch_size: int = 32,
71
+ num_workers: int = 4,
72
+ allow_uneven_batches: bool = False,
73
+ prefetch_factor: int = 2,
74
+ **kwargs: Any,
75
+ ):
76
+ super().__init__()
77
+ self.train_dataset_infos = train_infos
78
+ self.val_dataset_infos = val_infos
79
+ if self.val_dataset_infos is None:
80
+ self.val_dataset_infos = train_infos
81
+
82
+ self.batch_size = batch_size
83
+ self.num_workers = num_workers
84
+ self.allow_uneven_batches = allow_uneven_batches
85
+ self.prefetch_factor = prefetch_factor
86
+
87
+ if transforms is None:
88
+ transforms = default_image_pretraining_transforms()
89
+
90
+ self.train_transform, self.test_transform = transforms
91
+
92
+ def setup(self, stage=None):
93
+ train_transform = partial(transform_image, self.train_transform)
94
+ val_transform = partial(transform_image, self.test_transform)
95
+
96
+ self.train_dataset = build_datasets_from_info(
97
+ self.train_dataset_infos, split="train"
98
+ )
99
+ self.train_dataset.set_transform(train_transform)
100
+ self.val_dataset = build_datasets_from_info(
101
+ self.val_dataset_infos, split="validation"
102
+ )
103
+ self.val_dataset.set_transform(val_transform)
104
+
105
+ def train_dataloader(self):
106
+ return torch.utils.data.DataLoader(
107
+ self.train_dataset,
108
+ batch_size=self.batch_size,
109
+ num_workers=self.num_workers,
110
+ sampler=get_sampler(self.train_dataset, shuffle=True),
111
+ pin_memory=True,
112
+ persistent_workers=True,
113
+ prefetch_factor=self.prefetch_factor,
114
+ # uneven batches can cause distributed issues,
115
+ # drop last batch to prevent those.
116
+ # ideally, we don't need to drop these for unimodal cases
117
+ # but just to be safe
118
+ drop_last=True,
119
+ )
120
+
121
+ def val_dataloader(self):
122
+ return torch.utils.data.DataLoader(
123
+ self.val_dataset,
124
+ batch_size=self.batch_size,
125
+ num_workers=self.num_workers,
126
+ sampler=get_sampler(self.val_dataset, shuffle=False),
127
+ pin_memory=True,
128
+ persistent_workers=True,
129
+ prefetch_factor=self.prefetch_factor,
130
+ # uneven batches can cause distributed issues,
131
+ # drop last batch to prevent those.
132
+ # ideally, we don't need to drop these for unimodal cases
133
+ # but just to be safe
134
+ drop_last=True,
135
+ )
136
+
137
+ def test_dataloader(self):
138
+ return self.val_dataloader()
139
+
140
+ def on_before_batch_transfer(self, batch, *args):
141
+ if batch["label"].size(0) < self.batch_size and not self.allow_uneven_batches:
142
+ batch = pad_batch(batch, self.batch_size)
143
+ return batch
144
+
145
+
146
+ class TextDataModule(LightningDataModule):
147
+ def __init__(
148
+ self,
149
+ train_infos: List[HFDatasetInfo],
150
+ text_columns: List[str],
151
+ val_infos: Optional[List[HFDatasetInfo]] = None,
152
+ tokenizer: Optional[Callable] = None,
153
+ max_length: int = 512,
154
+ batch_size: int = 32,
155
+ num_workers: int = 4,
156
+ allow_uneven_batches: bool = False,
157
+ prefetch_factor: int = 2,
158
+ **kwargs: Any,
159
+ ):
160
+ super().__init__()
161
+ self.train_dataset_infos = train_infos
162
+ self.text_columns = text_columns
163
+ self.val_dataset_infos = val_infos
164
+ if self.val_dataset_infos is None:
165
+ self.val_dataset_infos = train_infos
166
+ self.tokenizer = tokenizer
167
+ self.max_length = max_length
168
+ self.batch_size = batch_size
169
+ self.num_workers = num_workers
170
+ self.allow_uneven_batches = allow_uneven_batches
171
+ self.prefetch_factor = prefetch_factor
172
+
173
+ def setup(self, stage=None):
174
+ if self.tokenizer is None:
175
+ self.tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
176
+ transform = partial(
177
+ encode_text_batch,
178
+ tokenizer=self.tokenizer,
179
+ padding="max_length",
180
+ max_length=self.max_length,
181
+ truncation=True,
182
+ return_tensors="pt",
183
+ return_special_tokens_mask=True,
184
+ text_columns=self.text_columns,
185
+ return_batch=True,
186
+ )
187
+ self.train_dataset = build_datasets_from_info(
188
+ self.train_dataset_infos, split="train"
189
+ )
190
+ self.train_dataset.set_transform(transform)
191
+ self.val_dataset = build_datasets_from_info(
192
+ self.val_dataset_infos, split="validation"
193
+ )
194
+ self.val_dataset.set_transform(transform)
195
+
196
+ def train_dataloader(self):
197
+ return self._build_dataloader(self.train_dataset)
198
+
199
+ def val_dataloader(self):
200
+ return self._build_dataloader(self.val_dataset, shuffle=False)
201
+
202
+ def _build_dataloader(self, dataset, drop_last=False, shuffle=True):
203
+ return torch.utils.data.DataLoader(
204
+ dataset,
205
+ batch_size=self.batch_size,
206
+ num_workers=self.num_workers,
207
+ sampler=get_sampler(dataset, shuffle),
208
+ pin_memory=True,
209
+ persistent_workers=True,
210
+ prefetch_factor=self.prefetch_factor,
211
+ collate_fn=self._build_collator(),
212
+ drop_last=drop_last,
213
+ )
214
+
215
+ def _build_collator(self):
216
+ return DefaultDataCollator()
217
+
218
+ def on_before_batch_transfer(self, batch, *args):
219
+ batch.pop("token_type_ids", None)
220
+ mask = batch.pop("attention_mask", None)
221
+ if mask is not None and mask.size(0) < self.batch_size and not self.allow_uneven_batches:
222
+ batch = pad_batch(batch, self.batch_size)
223
+ return batch
224
+
225
+ def on_after_batch_transfer(self, batch, *args):
226
+ batch["text"] = batch.pop("input_ids")
227
+ return batch
228
+
229
+
230
+ class MLMDataModule(TextDataModule):
231
+ def __init__(
232
+ self,
233
+ train_infos: List[HFDatasetInfo],
234
+ text_columns: List[str],
235
+ val_infos: Optional[List[HFDatasetInfo]] = None,
236
+ mlm_probability: float = 0.15,
237
+ ignore_index: int = -1,
238
+ **kwargs: Any,
239
+ ):
240
+ super().__init__(train_infos, text_columns, val_infos, **kwargs)
241
+ self.mlm_probability = mlm_probability
242
+ self.ignore_index = ignore_index
243
+
244
+ def setup(self, stage=None):
245
+ if self.tokenizer is None:
246
+ self.tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
247
+ transform = partial(
248
+ encode_text_batch,
249
+ tokenizer=self.tokenizer,
250
+ padding="max_length",
251
+ max_length=self.max_length,
252
+ truncation=True,
253
+ return_tensors="pt",
254
+ return_special_tokens_mask=True,
255
+ text_columns=self.text_columns,
256
+ return_batch=False,
257
+ )
258
+ self.train_dataset = build_datasets_from_info(
259
+ self.train_dataset_infos, split="train"
260
+ )
261
+ self.train_dataset.set_transform(transform)
262
+ self.val_dataset = build_datasets_from_info(
263
+ self.val_dataset_infos, split="validation"
264
+ )
265
+ self.val_dataset.set_transform(transform)
266
+
267
+ def _build_dataloader(self, dataset, drop_last=True, shuffle=True):
268
+ # uneven batches can cause distributed issues,
269
+ # drop last batch to prevent those.
270
+ # ideally, we don't need to drop these for unimodal cases
271
+ # but just to be safe
272
+ return super()._build_dataloader(dataset, drop_last=drop_last, shuffle=shuffle)
273
+
274
+ def _build_collator(self):
275
+ return DataCollatorForLanguageModeling(
276
+ self.tokenizer, mlm_probability=self.mlm_probability
277
+ )
278
+
279
+ def on_after_batch_transfer(self, batch, *args):
280
+ batch["text_masked"] = batch.pop("input_ids")
281
+ batch["mlm_labels"] = batch.pop("labels")
282
+ batch["mlm_labels"][batch["mlm_labels"] == -100] = self.ignore_index
283
+ return batch
284
+
285
+
286
+ class VLDataModule(LightningDataModule):
287
+ def __init__(
288
+ self,
289
+ train_infos: List[HFDatasetInfo],
290
+ val_infos: List[HFDatasetInfo],
291
+ text_transform: Optional[Callable] = None,
292
+ image_transforms: Optional[Tuple[Callable, Callable]] = None,
293
+ mlm_probability: float = 0.15,
294
+ batch_size: int = 32,
295
+ num_workers: int = 4,
296
+ finetuning: bool = False,
297
+ ignore_index: int = -1,
298
+ itm_probability: float = 0.1,
299
+ allow_uneven_batches: bool = False,
300
+ fetch_num_threads: int = 4,
301
+ fetch_retries: int = 0,
302
+ fetch_sleep_timer: int = 0,
303
+ fetch_timeout: Optional[float] = None,
304
+ fetch_batch_size: int = 50,
305
+ prefetch_factor=2,
306
+ **kwargs,
307
+ ):
308
+ super().__init__()
309
+
310
+ self.train_dataset_infos = train_infos
311
+ self.val_dataset_infos = val_infos
312
+ if self.val_dataset_infos is None:
313
+ self.val_dataset_infos = train_infos
314
+ if image_transforms is None:
315
+ if not finetuning:
316
+ image_transforms = default_image_pretraining_transforms()
317
+ else:
318
+ image_transforms = default_torchvision_transforms(use_dict=True)
319
+
320
+ self.train_image_transform, self.test_image_transform = image_transforms
321
+ self.text_transform = text_transform
322
+ self.mlm_probability = mlm_probability
323
+ self.batch_size = batch_size
324
+ self.num_workers = num_workers
325
+ self.ignore_index = ignore_index
326
+ self.itm_probability = itm_probability
327
+ self.allow_uneven_batches = allow_uneven_batches
328
+ self.fetch_num_threads = fetch_num_threads
329
+ self.fetch_retries = fetch_retries
330
+ self.fetch_sleep_timer = fetch_sleep_timer
331
+ self.fetch_timeout = fetch_timeout
332
+ self.fetch_batch_size = fetch_batch_size
333
+ self.prefetch_factor = prefetch_factor
334
+
335
+ def setup(self, stage=None):
336
+ if self.text_transform is None:
337
+ # TODO Update to use whole word mask vocab
338
+ text_tokenizer = BertTokenizer.from_pretrained(
339
+ TEXT_WHOLE_WORD_MASK_TOKENIZER
340
+ )
341
+ self.text_transform = default_text_transform(
342
+ text_tokenizer, max_text_length=VL_MAX_LENGTH_DEFAULT
343
+ )
344
+ self.text_tokenizer = self.text_transform.keywords["tokenizer"]
345
+ train_vl_transform = VLTransform(
346
+ self.train_image_transform, self.text_transform
347
+ )
348
+ val_vl_transform = VLTransform(self.test_image_transform, self.text_transform)
349
+
350
+ train_dataset = build_datasets_from_info(
351
+ self.train_dataset_infos, split="train"
352
+ )
353
+ train_dataset = train_dataset.map(
354
+ fetch_images,
355
+ batched=True,
356
+ batch_size=self.fetch_batch_size,
357
+ fn_kwargs={
358
+ "num_threads": self.fetch_num_threads,
359
+ "timeout": self.fetch_timeout,
360
+ "retries": self.fetch_retries,
361
+ "sleep_timer": self.fetch_sleep_timer,
362
+ },
363
+ )
364
+ train_dataset = train_dataset.filter(
365
+ lambda example: example["image"] is not None
366
+ )
367
+ self.train_dataset = train_dataset
368
+ self.train_dataset.set_transform(
369
+ partial(
370
+ train_vl_transform,
371
+ dataset=train_dataset.filter(lambda example: True),
372
+ itm_probability=self.itm_probability,
373
+ )
374
+ )
375
+
376
+ val_dataset = build_datasets_from_info(
377
+ self.val_dataset_infos, split="validation"
378
+ )
379
+
380
+ val_dataset = val_dataset.map(
381
+ fetch_images,
382
+ batched=True,
383
+ batch_size=self.fetch_batch_size,
384
+ fn_kwargs={
385
+ "num_threads": self.fetch_num_threads,
386
+ "timeout": self.fetch_timeout,
387
+ "retries": self.fetch_retries,
388
+ "sleep_timer": self.fetch_sleep_timer,
389
+ },
390
+ )
391
+ val_dataset = val_dataset.filter(lambda example: example["image"] is not None)
392
+ self.val_dataset = val_dataset
393
+ self.val_dataset.set_transform(
394
+ partial(
395
+ val_vl_transform,
396
+ dataset=self.val_dataset.filter(
397
+ lambda example: True
398
+ ), # Pass a copy to transform
399
+ itm_probability=self.itm_probability,
400
+ )
401
+ )
402
+
403
+ def train_dataloader(self):
404
+ return torch.utils.data.DataLoader(
405
+ self.train_dataset,
406
+ batch_size=self.batch_size,
407
+ num_workers=self.num_workers,
408
+ sampler=get_sampler(self.train_dataset),
409
+ collate_fn=self._build_collator(),
410
+ pin_memory=True,
411
+ persistent_workers=True,
412
+ prefetch_factor=self.prefetch_factor,
413
+ # uneven batches can cause distributed issues,
414
+ # drop last batch to prevent those.
415
+ drop_last=True,
416
+ )
417
+
418
+ def val_dataloader(self):
419
+ return torch.utils.data.DataLoader(
420
+ self.val_dataset,
421
+ batch_size=self.batch_size,
422
+ num_workers=self.num_workers,
423
+ sampler=get_sampler(self.val_dataset, shuffle=False),
424
+ collate_fn=self._build_collator(),
425
+ pin_memory=True,
426
+ persistent_workers=True,
427
+ prefetch_factor=self.prefetch_factor,
428
+ # uneven batches can cause distributed issues,
429
+ # drop last batch to prevent those.
430
+ drop_last=True,
431
+ )
432
+
433
+ def _build_collator(self):
434
+ return DataCollatorForWholeWordMaskRetainingBatch(
435
+ self.text_tokenizer, mlm_probability=self.mlm_probability
436
+ )
437
+
438
+ def on_before_batch_transfer(self, batch, *args):
439
+ batch.pop("token_type_ids", None)
440
+ mask = batch.pop("attention_mask", None)
441
+ if (
442
+ mask is not None
443
+ and mask.size(0) < self.batch_size
444
+ and not self.allow_uneven_batches
445
+ ):
446
+ batch = pad_batch(batch, self.batch_size)
447
+ return batch
448
+
449
+ def on_after_batch_transfer(self, batch, *args):
450
+ text_masked = batch.pop("input_ids")
451
+ mlm_labels = batch.pop("labels", None)
452
+ mlm_labels[mlm_labels == -100] = self.ignore_index
453
+ text = text_masked.detach().clone()
454
+ text[mlm_labels != -1] = mlm_labels[mlm_labels != -1]
455
+ batch.update(
456
+ {"mlm_labels": mlm_labels, "text": text, "text_masked": text_masked}
457
+ )
458
+ return batch
459
+
460
+
461
+ class TorchVisionDataModule(LightningDataModule):
462
+ def __init__(
463
+ self,
464
+ train_infos: List[TorchVisionDatasetInfo],
465
+ # Val info is not used for torchvision datamodule, but kept to keep things consistent
466
+ val_infos: Optional[List[TorchVisionDatasetInfo]] = None,
467
+ dataset_root: Optional[str] = None,
468
+ image_transforms: Optional[Tuple[Callable, Callable]] = None,
469
+ batch_size: int = 32,
470
+ num_workers: int = 4,
471
+ prefetch_factor: int = 2,
472
+ **kwargs: Any,
473
+ ):
474
+ super().__init__()
475
+ self.train_info = train_infos[0]
476
+ if val_infos is None:
477
+ val_infos = train_infos
478
+ self.val_info = val_infos[0]
479
+
480
+ self.train_class_ptr, self.train_root = self._parse_info(
481
+ self.train_info, dataset_root=dataset_root
482
+ )
483
+ self.val_class_ptr, self.val_root = self._parse_info(
484
+ self.val_info, dataset_root=dataset_root
485
+ )
486
+
487
+ if image_transforms is None:
488
+ image_transforms = default_torchvision_transforms()
489
+
490
+ self.train_transform, self.test_transform = image_transforms
491
+ self.batch_size = batch_size
492
+ self.num_workers = num_workers
493
+ self.prefetch_factor = prefetch_factor
494
+
495
+ def _parse_info(
496
+ self, info: TorchVisionDatasetInfo, dataset_root: Optional[str] = None
497
+ ):
498
+ assert hasattr(
499
+ torchvision.datasets, info.key
500
+ ), f"No dataset named {info.key} present in torchvision.datasets"
501
+ class_ptr = getattr(torchvision.datasets, info.key)
502
+ if dataset_root is None:
503
+ dataset_root = os.path.join(TRANSFORMERS_CACHE, "datasets", "torchvision")
504
+ dataset_root = os.path.join(dataset_root, class_ptr.__name__.lower())
505
+ os.makedirs(dataset_root, exist_ok=True)
506
+
507
+ return class_ptr, dataset_root
508
+
509
+ def setup(self, stage=None):
510
+ self.train_dataset = self.train_class_ptr(
511
+ self.train_root,
512
+ split=self.train_info.train_split,
513
+ transform=self.train_transform,
514
+ download=True,
515
+ )
516
+
517
+ if self.val_info.has_val:
518
+ self.val_dataset = self.val_class_ptr(
519
+ self.val_root,
520
+ split=self.val_info.val_split,
521
+ transform=self.test_transform,
522
+ download=True,
523
+ )
524
+
525
+ self.test_dataset = self.val_class_ptr(
526
+ self.val_root,
527
+ split=self.val_info.test_split,
528
+ transform=self.test_transform,
529
+ download=True,
530
+ )
531
+
532
+ def train_dataloader(self):
533
+ return self._build_dataloader(self.train_dataset)
534
+
535
+ def val_dataloader(self):
536
+ if self.val_info.has_val:
537
+ dataset = self.val_dataset
538
+ else:
539
+ dataset = self.test_dataset
540
+
541
+ return self._build_dataloader(dataset, shuffle=False)
542
+
543
+ def test_dataloader(self):
544
+ return self._build_dataloader(self.test_dataset, shuffle=False)
545
+
546
+ def _build_dataloader(self, dataset: torch.utils.data.Dataset, shuffle=True):
547
+ return torch.utils.data.DataLoader(
548
+ dataset,
549
+ sampler=get_sampler(dataset, shuffle),
550
+ batch_size=self.batch_size,
551
+ num_workers=self.num_workers,
552
+ pin_memory=True,
553
+ persistent_workers=True,
554
+ prefetch_factor=self.prefetch_factor,
555
+ )
556
+
557
+ def on_before_batch_transfer(self, batch, *args):
558
+ images, targets = batch
559
+ batch = {"image": images, "labels": targets}
560
+ return batch
multimodal/examples/flava/native/model.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Tuple
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torchmultimodal.models.flava.model import flava_model_for_pretraining
12
+ from transformers.optimization import get_cosine_schedule_with_warmup
13
+
14
+
15
+ def get_optimizer(
16
+ model: torch.nn.Module,
17
+ learning_rate: float = 0.0002,
18
+ adam_eps: float = 1.0e-08,
19
+ adam_weight_decay: float = 0.01,
20
+ adam_betas: Tuple[float, float] = (0.9, 0.999),
21
+ warmup_steps: int = 2000,
22
+ max_steps: int = 450000,
23
+ ):
24
+ optimizer = torch.optim.AdamW(
25
+ model.parameters(),
26
+ lr=learning_rate,
27
+ betas=adam_betas,
28
+ eps=adam_eps,
29
+ weight_decay=adam_weight_decay,
30
+ )
31
+ scheduler = get_cosine_schedule_with_warmup(
32
+ optimizer,
33
+ num_warmup_steps=warmup_steps,
34
+ num_training_steps=max_steps,
35
+ )
36
+ return optimizer, scheduler
37
+
38
+
39
+ class FLAVAPreTrainModule(nn.Module):
40
+ def __init__(
41
+ self,
42
+ use_bf16: bool = True,
43
+ **flava_pretraining_kwargs: Any,
44
+ ):
45
+ super().__init__()
46
+ self.model = flava_model_for_pretraining(**flava_pretraining_kwargs)
47
+ self.use_bf16 = use_bf16
48
+
49
+ def forward(self, batch, action=None):
50
+ # super hacky
51
+ if action == "encode_text":
52
+ return self.model.encode_text(batch)
53
+ elif action == "encode_image":
54
+ return self.model.encode_image(batch)
55
+
56
+ if "image" in batch and ("text" in batch or "text_masked" in batch):
57
+ required_embedding = "mm"
58
+ elif "image" in batch:
59
+ required_embedding = "image"
60
+ elif "text" in batch or "text_masked" in batch:
61
+ required_embedding = "text"
62
+ else:
63
+ raise RuntimeError("Batch needs to have either or both 'image' and 'text'.")
64
+
65
+ output = self.model(
66
+ image=batch.get("image"),
67
+ image_for_codebook=batch.get("image_for_codebook"),
68
+ image_patches_mask=batch.get("image_patches_mask"),
69
+ text=batch.get("text"),
70
+ text_masked=batch.get("text_masked"),
71
+ mlm_labels=batch.get("mlm_labels"),
72
+ itm_labels=batch.get("itm_labels"),
73
+ required_embedding=required_embedding,
74
+ )
75
+ return output
76
+
77
+ def encode_text(self, *args, **kwargs):
78
+ return self.model.encode_text(*args, **kwargs)
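`FLAVAPreTrainModule` funnels everything through `forward` and uses the optional `action` argument to reach the text/image encoders, which keeps those calls working after the module is wrapped in DDP/FSDP (this is how `run_imagenet_zero_shot` in `native/utils.py` uses it). A minimal usage sketch, assuming the default model config; the sequence length, vocab size, and image resolution below are illustrative placeholders, not values required by the module:

```python
import torch
from flava.native.model import FLAVAPreTrainModule

module = FLAVAPreTrainModule().eval()

# Encode a batch of tokenized captions; batch size, sequence length, and
# vocab size here are placeholders for whatever the text transform produces.
text_ids = torch.randint(0, 30522, (2, 77))
with torch.no_grad():
    text_emb = module(text_ids, action="encode_text")

# Encode a batch of images; 224x224 is assumed as the usual FLAVA input size.
images = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    image_emb = module(images, action="encode_image")

# With no action, the forward infers the modality from the batch keys
# ("image", "text"/"text_masked") exactly as the data modules emit them.
```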
multimodal/examples/flava/native/train.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # example command to train:
8
+ # `torchrun --nproc_per_node=8 -m flava.native.train config=flava/native/configs/pretrain_debug.yaml`
9
+
10
+
11
+ import time
12
+ from functools import partial
13
+ from typing import Any, Dict, Tuple, Union
14
+
15
+ import datasets
16
+ import numpy as np
17
+ import torch
18
+ import torch.distributed as dist
19
+ from common.data import MultiDataModule
20
+ from flava.definitions import FLAVAArguments
21
+ from flava.native.data import (
22
+ default_text_transform,
23
+ ImageDataModule,
24
+ MLMDataModule,
25
+ VL_MAX_LENGTH_DEFAULT,
26
+ VLDataModule,
27
+ )
28
+ from flava.native.model import FLAVAPreTrainModule, get_optimizer
29
+ from flava.native.utils import (
30
+ build_config,
31
+ enable_tf32,
32
+ get_model_parameters,
33
+ get_model_size_gb,
34
+ move_to_device,
35
+ print0,
36
+ run_imagenet_zero_shot,
37
+ set_seed,
38
+ setup_distributed_device,
39
+ )
40
+ from flava.utils import build_datamodule_kwargs
41
+
42
+ from omegaconf import DictConfig, OmegaConf
43
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
44
+ apply_activation_checkpointing,
45
+ checkpoint_wrapper,
46
+ CheckpointImpl,
47
+ )
48
+ from torch.distributed.elastic.multiprocessing.errors import record
49
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
50
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
51
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
52
+ from torch.nn.parallel import DistributedDataParallel as DDP
53
+ from torch.utils.tensorboard import SummaryWriter
54
+ from torchmultimodal.models.flava.image_encoder import ImageTransformer
55
+ from torchmultimodal.models.flava.text_encoder import BERTTextEncoder
56
+ from torchmultimodal.models.flava.transformer import (
57
+ FLAVATransformerWithoutEmbeddings,
58
+ TransformerEncoderLayer,
59
+ )
60
+ from torchmultimodal.modules.losses.flava import FLAVAPretrainingLossOutput
61
+
62
+
63
+ def get_datamodules(config: FLAVAArguments) -> Tuple[MultiDataModule, ImageDataModule]:
64
+ datamodules = []
65
+
66
+ # also needed for the imagenet eval callback
67
+ imagenet_datamodule = ImageDataModule(
68
+ **build_datamodule_kwargs(config.datasets.image, config.training)
69
+ )
70
+ for dataset in config.datasets.selected:
71
+ if dataset == "image":
72
+ datamodules.append(imagenet_datamodule)
73
+ elif dataset == "text":
74
+ datamodules.append(
75
+ MLMDataModule(
76
+ **build_datamodule_kwargs(config.datasets.text, config.training)
77
+ )
78
+ )
79
+ elif dataset == "vl":
80
+ datamodules.append(
81
+ VLDataModule(
82
+ **build_datamodule_kwargs(config.datasets.vl, config.training)
83
+ )
84
+ )
85
+ else:
86
+ raise ValueError(f"unknown dataset: {dataset}")
87
+
88
+ return MultiDataModule(datamodules), imagenet_datamodule
89
+
90
+
91
+ @record
92
+ class Trainer:
93
+ def __init__(self, config: DictConfig):
94
+ if config.training.seed != -1:
95
+ set_seed(config.training.seed)
96
+
97
+ self.device: torch.device = setup_distributed_device()
98
+ self.config: DictConfig = config
99
+ self.rank: int = dist.get_rank()
100
+ self._logger: SummaryWriter = SummaryWriter(
101
+ f"logs/{config.training.strategy}/{int(time.time())}"
102
+ )
103
+ self.steps: int = -1
104
+ self.epochs: int = -1
105
+
106
+ multi_module, image_module = get_datamodules(config)
107
+
108
+ self.datamodule: MultiDataModule = multi_module
109
+ self.datamodule.setup("fit")
110
+
111
+ self.imagenet_val_dataloader = image_module.val_dataloader()
112
+ self.imagenet_val_text_transform = default_text_transform(
113
+ max_text_length=VL_MAX_LENGTH_DEFAULT
114
+ )
115
+
116
+ self.half_dtype = (
117
+ torch.bfloat16
118
+ if config.training.half_precision_format == "bfloat16"
119
+ else torch.float16
120
+ )
121
+
122
+ self.scaler = ShardedGradScaler() if config.training.enable_amp else None
123
+
124
+ def log(
125
+ self,
126
+ name: str,
127
+ value: Union[torch.Tensor, float, int],
128
+ log_rank_0: bool = True,
129
+ always_log: bool = False,
130
+ ):
131
+ if log_rank_0 and self.rank != 0:
132
+ return
133
+
134
+ if always_log or self.steps % self.config.training.log_interval == 0:
135
+ self._logger.add_scalar(name, value, self.steps)
136
+
137
+ def create_model(self) -> torch.nn.Module:
138
+ model_config = self.config.get("model", {})
139
+ print0(f"using model config: {model_config}")
140
+
141
+ model = FLAVAPreTrainModule(**model_config)
142
+ strategy = self.config.training.strategy
143
+
144
+ print0(
145
+ f"before {strategy} model parameters: {get_model_parameters(model):,}, "
146
+ f"size: {get_model_size_gb(model):.3} GB"
147
+ )
148
+
149
+ if self.config.training.activation_checkpointing:
150
+ check_fn = lambda submodule: isinstance(submodule, TransformerEncoderLayer)
151
+ checkpoint_impl = CheckpointImpl.REENTRANT
152
+
153
+ # DDP gradient hooks have compatibility issues with REENTRANT autograd
154
+ if strategy == "ddp":
155
+ checkpoint_impl = CheckpointImpl.NO_REENTRANT
156
+
157
+ checkpoint_wrapper_fn = partial(
158
+ checkpoint_wrapper,
159
+ offload_to_cpu=False,
160
+ checkpoint_impl=checkpoint_impl,
161
+ )
162
+ apply_activation_checkpointing(
163
+ model,
164
+ checkpoint_wrapper_fn=checkpoint_wrapper_fn,
165
+ check_fn=check_fn,
166
+ )
167
+
168
+ if strategy == "ddp":
169
+ # TODO do we have to do this in FSDP too? see https://github.com/pytorch/pytorch/issues/75478
170
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
171
+ model = model.to(self.device)
172
+
173
+ print0(
174
+ f"after moving to cuda: {torch.cuda.memory_allocated()/1024**3:.3} GB"
175
+ )
176
+
177
+ model = DDP(
178
+ model,
179
+ device_ids=[self.rank],
180
+ find_unused_parameters=True,
181
+ gradient_as_bucket_view=True,
182
+ )
183
+ print0(f"after DDP: {torch.cuda.memory_allocated()/1024**3:.3} GB")
184
+ elif strategy == "fsdp":
185
+ mp = None
186
+ if self.config.training.enable_half_reduce_in_fsdp:
187
+ mp = MixedPrecision(
188
+ # param_dtype=self.half_dtype, not working
189
+ reduce_dtype=self.half_dtype,
190
+ # buffer_dtype=self.half_dtype,
191
+ )
192
+
193
+ model = FSDP(
194
+ model,
195
+ mixed_precision=mp,
196
+ device_id=self.device,
197
+ auto_wrap_policy=partial(
198
+ transformer_auto_wrap_policy,
199
+ transformer_layer_cls={
200
+ TransformerEncoderLayer,
201
+ ImageTransformer,
202
+ BERTTextEncoder,
203
+ FLAVATransformerWithoutEmbeddings,
204
+ },
205
+ ),
206
+ limit_all_gathers=True,
207
+ )
208
+
209
+ print0(f"after FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")
210
+
211
+ else:
212
+ raise ValueError(f"unknown strategy: {strategy}")
213
+
214
+ print0(
215
+ f"after {strategy} model parameters: {get_model_parameters(model):,}, "
216
+ f"size: {get_model_size_gb(model):.3} GB"
217
+ )
218
+
219
+ return model
220
+
221
+ def calculate_loss(
222
+ self, output: FLAVAPretrainingLossOutput, validation=False
223
+ ) -> torch.Tensor:
224
+ losses = output.losses
225
+
226
+ total_loss = 0
227
+ for key in losses:
228
+ if losses[key] is not None:
229
+ total_loss += losses[key]
230
+ loss_reduce = losses[key].detach()
231
+ dist.reduce(loss_reduce, dst=0)
232
+ if validation:
233
+ mode = "validation"
234
+ else:
235
+ mode = "train"
236
+ self.log(
237
+ f"{mode}/losses/{key}",
238
+ loss_reduce.item() / dist.get_world_size(),
239
+ )
240
+
241
+ return total_loss
242
+
243
+ def preprocess_data(self, data: Dict[str, Any]):
244
+ data = self.datamodule.on_before_batch_transfer(data, None)
245
+ data = move_to_device(data, self.device)
246
+ return self.datamodule.on_after_batch_transfer(data, None)
247
+
248
+ def _log_iteration_times(self, iteration_times):
249
+ profile_warmup_steps = self.config.get("profile_warmup_steps", 100)
250
+ start_idx = (
251
+ profile_warmup_steps
252
+ if profile_warmup_steps < self.config.training.max_steps
253
+ else 0
254
+ )
255
+ iteration_times = iteration_times[start_idx:]
256
+ avg_it_time = np.mean(iteration_times)
257
+ avg_throughput = (
258
+ self.config.training.batch_size * dist.get_world_size()
259
+ ) / avg_it_time
260
+ print0(f"Average over {len(iteration_times)} steps")
261
+ print0(f"Average iteration time {round(avg_it_time,4)}")
262
+ print0(f"Average throughput {round(avg_throughput,4)}")
263
+
264
+ def train(self) -> None:
265
+ print0(OmegaConf.to_container(self.config.training))
266
+ self.model = self.create_model()
267
+ model = self.model
268
+
269
+ optimizer, scheduler = get_optimizer(
270
+ model,
271
+ **self.config.training.optimizer,
272
+ )
273
+
274
+ iteration_times = []
275
+
276
+ while True:
277
+ t0 = time.time()
278
+ self.epochs += 1
279
+ dataloader = self.datamodule.train_dataloader()
280
+ dataloader.set_epoch(self.epochs)
281
+
282
+ for i, data in enumerate(dataloader):
283
+ torch.cuda.reset_peak_memory_stats()
284
+
285
+ self.steps += 1
286
+
287
+ if self.config.training.max_steps < self.steps:
288
+ if self.rank == 0:
289
+ self._log_iteration_times(iteration_times)
290
+ print0("Max steps reached, exiting")
291
+ return
292
+
293
+ model.train()
294
+ data = self.preprocess_data(data)
295
+ optimizer.zero_grad(set_to_none=True)
296
+
297
+ with torch.cuda.amp.autocast(
298
+ dtype=self.half_dtype, enabled=bool(self.scaler)
299
+ ):
300
+ output = model(data)
301
+ print0(
302
+ f"after forward pass {torch.cuda.memory_allocated()/1024**3:.3} GB"
303
+ )
304
+ self.log(
305
+ "stats/fwd memory alloc",
306
+ torch.cuda.memory_allocated() / 1024**3,
307
+ )
308
+ self.log(
309
+ "stats/fwd memory reserved",
310
+ torch.cuda.memory_reserved() / 1024**3,
311
+ )
312
+
313
+ total_loss = self.calculate_loss(output)
314
+
315
+ if self.scaler:
316
+ self.scaler.scale(total_loss).backward()
317
+ self.scaler.step(optimizer)
318
+ self.scaler.update()
319
+ else:
320
+ total_loss.backward()
321
+ optimizer.step()
322
+
323
+ scheduler.step()
324
+ torch.cuda.synchronize()
325
+ t1 = time.time()
326
+ batch_time = t1 - t0
327
+ batch_size = self.config.training.batch_size * dist.get_world_size()
328
+ items_time = batch_size / (t1 - t0)
329
+
330
+ t0 = t1
331
+ self.log("stats/sec per batch", batch_time)
332
+ self.log("stats/items per sec", items_time)
333
+
334
+ total_loss = total_loss.detach()
335
+ dist.reduce(total_loss, dst=0)
336
+
337
+ if self.rank == 0:
338
+ norm_total_loss = total_loss.item() / dist.get_world_size()
339
+
340
+ print(
341
+ f"epoch: {self.epochs} step {self.steps} loss: {norm_total_loss:.4}"
342
+ )
343
+ self.log("train/loss", norm_total_loss)
344
+ self.log("stats/batch_size", batch_size)
345
+
346
+ iteration_times.append(batch_time)
347
+
348
+ cuda_info = torch.cuda.memory_stats()
349
+ print("cuda alloc retries ", cuda_info.get("num_alloc_retries", 0))
350
+
351
+ self.log(
352
+ "stats/max_gpu_allocated_gb",
353
+ torch.cuda.max_memory_allocated() / 1024**3,
354
+ )
355
+ # TODO implement imagenet eval
356
+ # TODO implement checkpoint saving
357
+
358
+ self.validate()
359
+
360
+ def validate(self):
361
+ if self.steps % self.config.training.validation_steps != 0 or self.steps == 0:
362
+ return
363
+
364
+ model = self.model
365
+ model.eval()
366
+ print0("evaluating")
367
+
368
+ validation_loader = self.datamodule.val_dataloader()
369
+ validation_loss = torch.Tensor([0]).to(self.device)
370
+
371
+ for data in validation_loader:
372
+ data = self.preprocess_data(data)
373
+ with torch.no_grad():
374
+ with torch.cuda.amp.autocast(
375
+ dtype=self.half_dtype, enabled=bool(self.scaler)
376
+ ):
377
+ output = model(data)
378
+ total_loss = self.calculate_loss(output, validation=True)
379
+ validation_loss += total_loss.detach()
380
+
381
+ dist.reduce(validation_loss, dst=0)
382
+ norm_validation_loss = validation_loss.item() / dist.get_world_size()
383
+
384
+ print0(f"step {self.steps} EVAL loss: {norm_validation_loss:.4}")
385
+
386
+ def imagenet_validate(self):
387
+ print0("imagenet validation")
388
+ with torch.no_grad():
389
+ with torch.cuda.amp.autocast(
390
+ dtype=self.half_dtype, enabled=bool(self.scaler)
391
+ ):
392
+ metrics = run_imagenet_zero_shot(
393
+ self.model,
394
+ self.imagenet_val_dataloader,
395
+ self.device,
396
+ self.imagenet_val_text_transform,
397
+ )
398
+ if metrics is not None:
399
+ for key in metrics:
400
+ self.log(
401
+ f"val/imagenet/{key}",
402
+ metrics[key],
403
+ always_log=True,
404
+ )
405
+
406
+
407
+ if __name__ == "__main__":
408
+ datasets.logging.set_verbosity_error() # too spammy
409
+
410
+ config: FLAVAArguments = build_config()
411
+ if config.training.enable_tf32:
412
+ enable_tf32()
413
+
414
+ trainer = Trainer(config)
415
+ trainer.train()
multimodal/examples/flava/native/utils.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import random
9
+ from typing import Any
10
+
11
+ import torch
12
+ from flava.data.imagenet_zeroshot_data import (
13
+ imagenet_classnames,
14
+ openai_imagenet_template,
15
+ )
16
+ from hydra.utils import instantiate
17
+ from omegaconf import DictConfig, OmegaConf
18
+ from torch import distributed as dist
19
+ from tqdm import tqdm
20
+
21
+ # optional syntax-highlighting for console output
22
+ try:
23
+ from rich.console import Console
24
+
25
+ c = Console(force_terminal=True)
26
+ print = c.log
27
+ except ImportError:
28
+ pass
29
+
30
+
31
+ def build_config() -> DictConfig:
32
+ cli_conf = OmegaConf.from_cli()
33
+ yaml_conf = OmegaConf.load(cli_conf.config)
34
+ conf = instantiate(yaml_conf)
35
+ conf = OmegaConf.merge(conf, cli_conf)
36
+ return conf
37
+
38
+
39
+ # TODO replace with tlc.copy_data_to_device
40
+ def move_to_device(obj: Any, device: torch.device) -> Any:
41
+ if isinstance(obj, dict):
42
+ d = {}
43
+ for k, v in obj.items():
44
+ d[k] = move_to_device(v, device)
45
+ return d
46
+ if isinstance(obj, list):
47
+ l = []
48
+ for v in obj:
49
+ l.append(move_to_device(v, device))
50
+ return l
51
+
52
+ return obj.to(device)
53
+
54
+
55
+ def get_model_size_gb(model: torch.nn.Module) -> float:
56
+ return sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)
57
+
58
+
59
+ def get_model_parameters(model: torch.nn.Module) -> int:
60
+ return sum(p.numel() for p in model.parameters())
61
+
62
+
63
+ def set_seed(seed: int) -> None:
64
+ torch.manual_seed(seed)
65
+ random.seed(seed)
66
+
67
+
68
+ def setup_distributed_device() -> torch.device:
69
+ if not torch.cuda.is_available() or not dist.is_available():
70
+ return torch.device("cpu")
71
+
72
+ dist.init_process_group("nccl")
73
+ local_rank = int(os.environ["LOCAL_RANK"])
74
+ print("local rank", local_rank)
75
+ torch.cuda.set_device(local_rank)
76
+ return torch.device(f"cuda:{local_rank}")
77
+
78
+
79
+ def print0(*args, **kwargs) -> None:
80
+ if not dist.is_initialized() or dist.get_rank() == 0:
81
+ print(*args, **kwargs)
82
+
83
+
84
+ def enable_tf32() -> None:
85
+ torch.backends.cudnn.allow_tf32 = True
86
+ torch.backends.cuda.matmul.allow_tf32 = True
87
+
88
+
89
+ def rank0_only(func):
90
+ def wrapper(*args, **kwargs):
91
+ if not dist.is_initialized() or dist.get_rank() == 0:
92
+ return func(*args, **kwargs)
93
+
94
+ return wrapper
95
+
96
+
97
+ # zero shot classifier functions
98
+
99
+
100
+ def _zero_shot_classifier(model, device, text_transform, *args, **kwargs):
101
+ zeroshot_weights = []
102
+ for classname in tqdm(imagenet_classnames):
103
+ texts = text_transform(
104
+ [template(classname) for template in openai_imagenet_template]
105
+ )["input_ids"]
106
+ texts = texts.to(device)
107
+ class_embeddings = model(texts, action="encode_text")
108
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
109
+ class_embedding = class_embeddings.mean(dim=0)
110
+ class_embedding /= class_embedding.norm()
111
+ zeroshot_weights.append(class_embedding)
112
+
113
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
114
+ return zeroshot_weights
115
+
116
+
117
+ def _accuracy(output, target, topk=(1,)):
118
+ pred = output.topk(max(topk), 1, True, True)[1].t()
119
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
120
+ return [
121
+ float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
122
+ for k in topk
123
+ ]
124
+
125
+
126
+ def run_imagenet_zero_shot(model, dataloader, device, text_transform, *args, **kwargs):
127
+ print0("Starting ImageNet Zero-Shot Eval")
128
+ print0("Building classifier")
129
+ classifier = _zero_shot_classifier(model, device, text_transform)
130
+ print0("Classifier built")
131
+ top1, top5, n = 0.0, 0.0, 0.0
132
+ for i, sample in tqdm(enumerate(dataloader)):
133
+ images = sample["image"]
134
+ target = sample["label"]
135
+ images = images.to(device)
136
+ target = target.to(device)
137
+
138
+ # predict
139
+ # if hasattr(model, "module"):
140
+ # image_features = model.module.encode_image({"image": images})
141
+ # else:
142
+ image_features = model(images, action="encode_image")
143
+ image_features /= image_features.norm(dim=-1, keepdim=True)
144
+ logits = 100.0 * image_features @ classifier
145
+
146
+ # measure accuracy
147
+ acc1, acc5 = _accuracy(logits, target, topk=(1, 5))
148
+ top1 += acc1
149
+ top5 += acc5
150
+ n += images.size(0)
151
+ if i == 5:
152
+ break
153
+
154
+ top1 = top1 / n
155
+ top5 = top5 / n
156
+ results = {}
157
+ results["imagenet-zeroshot-val-top1"] = top1
158
+ results["imagenet-zeroshot-val-top5"] = top5
159
+ print0("results: ", results)
160
+ return results
multimodal/examples/flava/notebooks/RemapFLAVACheckpoint.ipynb ADDED
@@ -0,0 +1,172 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "7cc982d1",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Re-map FLAVA checkpoint\n",
9
+ "\n",
10
+ "Modifying FLAVA's components can cause existing model checkpoints to go out of sync with the updated architecture. This notebook shows how to load the existing checkpoint, re-map the old layers to the new layers, and save the new checkpoint.\n",
11
+ "\n",
12
+ "To upload a new checkpoint, you must have access to the PyTorch AWS S3 account, and manually upload it from a local copy."
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "id": "411e4191",
18
+ "metadata": {},
19
+ "source": [
20
+ "### Load original model\n",
21
+ "\n",
22
+ "Load the existing checkpoint into the FLAVA class to see what the architecture currently is."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 3,
28
+ "id": "88ee917b",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "import torch\n",
33
+ "from torchmultimodal.models.flava.model import flava_model_for_classification, flava_model_for_pretraining\n",
34
+ "\n",
35
+ "# flava_classification = flava_model_for_classification(num_classes=3)\n",
36
+ "flava_pretraining = flava_model_for_pretraining(pretrained_model_key='flava_full')"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "id": "5f00b369",
42
+ "metadata": {},
43
+ "source": [
44
+ "### Print summary"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "cc286394",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "flava_pretraining"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "id": "0d774455",
60
+ "metadata": {},
61
+ "source": [
62
+ "### Mapping function\n",
63
+ "\n",
64
+ "Replace this function with the code needed to map the old layer weights to the new layer weights."
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 4,
70
+ "id": "cc9e4537",
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "import re\n",
75
+ "\n",
76
+ "def map_state_dict(state_dict):\n",
77
+ " mapped_state_dict = {}\n",
78
+ " for param, val in state_dict.items():\n",
79
+ " res = re.search('attention.attention', param)\n",
80
+ " if res:\n",
81
+ " idx = res.start()\n",
82
+ " new_param = param[:idx] + param[idx+10:]\n",
83
+ " else:\n",
84
+ " new_param = param\n",
85
+ " mapped_state_dict[new_param] = val\n",
86
+ " return mapped_state_dict"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "id": "29870590",
92
+ "metadata": {},
93
+ "source": [
94
+ "### Load old state dict"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 5,
100
+ "id": "41f64d26",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "# Load from url, replace this path if it changes\n",
105
+ "# old_model_url = 'https://download.pytorch.org/models/multimodal/flava/flava_model.pt'\n",
106
+ "# old_state_dict = torch.hub.load_state_dict_from_url(old_model_url)\n",
107
+ "\n",
108
+ "# Or get from loaded model\n",
109
+ "old_state_dict = flava_pretraining.model.state_dict()"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "markdown",
114
+ "id": "75322113",
115
+ "metadata": {},
116
+ "source": [
117
+ "### Perform re-mapping"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 6,
123
+ "id": "17363ae8",
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "#new_state_dict = map_state_dict(old_state_dict)\n",
128
+ "new_state_dict = old_state_dict"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "markdown",
133
+ "id": "d94c4133",
134
+ "metadata": {},
135
+ "source": [
136
+ "### Save updated checkpoint"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 7,
142
+ "id": "bc6baad9",
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "save_path = '/Users/rafiayub/flava_model.pt'\n",
147
+ "torch.save(new_state_dict, save_path)"
148
+ ]
149
+ }
150
+ ],
151
+ "metadata": {
152
+ "kernelspec": {
153
+ "display_name": "Python 3 (ipykernel)",
154
+ "language": "python",
155
+ "name": "python3"
156
+ },
157
+ "language_info": {
158
+ "codemirror_mode": {
159
+ "name": "ipython",
160
+ "version": 3
161
+ },
162
+ "file_extension": ".py",
163
+ "mimetype": "text/x-python",
164
+ "name": "python",
165
+ "nbconvert_exporter": "python",
166
+ "pygments_lexer": "ipython3",
167
+ "version": "3.9.12"
168
+ }
169
+ },
170
+ "nbformat": 4,
171
+ "nbformat_minor": 5
172
+ }
multimodal/examples/flava/tools/convert_weights.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+
9
+ import torch
10
+ from torchmultimodal.models.flava.model import flava_model_for_pretraining
11
+
12
+ KEY_REPLACEMENTS = {
13
+ "image_encoder.module": "image_encoder",
14
+ "text_encoder.module": "text_encoder",
15
+ "mm_encoder.module": "mm_encoder",
16
+ "mm_encoder.encoder.cls_token": "mm_encoder.cls_token",
17
+ "mm_image_projection": "image_to_mm_projection",
18
+ "mm_text_projection": "text_to_mm_projection",
19
+ "model.heads.cmd.mim_head": "loss.mmm_loss.mim",
20
+ "model.heads.cmd.mlm_head": "loss.mmm_loss.mlm",
21
+ "model.heads.fairseq_mlm": "loss.mlm_loss",
22
+ "model.heads.imagenet.mim_head": "loss.mim_loss",
23
+ "cls.predictions.transform": "cls",
24
+ "cls.predictions": "cls",
25
+ "cls.LayerNorm": "cls.layer_norm",
26
+ "model.text_projection": "loss.contrastive_loss.text_projection",
27
+ "model.image_projection": "loss.contrastive_loss.image_projection",
28
+ "model.heads.cmd.clip_head.logit_scale": "loss.contrastive_loss.logit_scale",
29
+ "model.heads.cmd.itm_head": "loss.itm_loss",
30
+ "intermediate.dense": "intermediate",
31
+ "output.dense": "output",
32
+ }
33
+
34
+
35
+ def convert_weights(args):
36
+ ckpt = torch.load(args.ckpt_file, map_location="cpu")
37
+ flava = flava_model_for_pretraining()
38
+ model = ckpt["model"]
39
+ import pdb
40
+
41
+ pdb.set_trace()
42
+ for key in list(model.keys()):
43
+ original = key
44
+ for option, replacement in KEY_REPLACEMENTS.items():
45
+ key = key.replace(option, replacement)
46
+ model[key] = model.pop(original)
47
+
48
+ if args.add_codebook:
49
+ # Since codebook is anyways not trained in FLAVA pretraining
50
+ # we can use the pretrained one that we get from FLAVA initialized
51
+ # model
52
+ model.update(
53
+ {
54
+ f"image_codebook.{key}": value
55
+ for key, value in flava.image_codebook.state_dict().items()
56
+ }
57
+ )
58
+ flava.load_state_dict(model)
59
+
60
+ # Let's save the model now.
61
+ torch.save(flava.state_dict(), args.save_file)
62
+
63
+
64
+ if __name__ == "__main__":
65
+ parser = argparse.ArgumentParser(description="Convert weights")
66
+ parser.add_argument("ckpt_file", type=str)
67
+ parser.add_argument("save_file", type=str)
68
+ parser.add_argument("--add_codebook", action="store_true")
69
+
70
+ args = parser.parse_args()
71
+
72
+ convert_weights(args)
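The loop in `convert_weights` applies `KEY_REPLACEMENTS` as ordered string substitutions over every checkpoint key. A small, self-contained illustration of that mechanic, using a hypothetical key and a two-entry subset of the table above:

```python
# Subset of KEY_REPLACEMENTS, shown only to make the substitution order concrete.
replacements = {
    "model.heads.cmd.mlm_head": "loss.mmm_loss.mlm",
    "cls.predictions.transform": "cls",
}

# Hypothetical checkpoint key; not taken from a real FLAVA checkpoint.
key = "model.heads.cmd.mlm_head.cls.predictions.transform.dense.weight"
for option, replacement in replacements.items():
    key = key.replace(option, replacement)

print(key)  # -> loss.mmm_loss.mlm.cls.dense.weight
```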
multimodal/examples/mugen/data/README.md ADDED
@@ -0,0 +1,10 @@
1
+ This folder contains code for interfacing with the [MUGEN dataset](https://mugen-org.github.io). The MUGEN dataset contains over 300k videos from the game CoinRun, each with corresponding audio and text.
2
+
3
+ Before using this code,
4
+
5
+ 1. Download the 3.2s-video dataset [here](https://mugen-org.github.io/download) and save as `datasets/coinrun` in your working directory.
6
+ * In each of `datasets/coinrun/coinrun_dataset_jsons/release/{train/val/test}.json`, change the value of `json_object["metadata"]["data_folder"]` to the absolute path of `datasets/coinrun`, e.g. `"/path/to/datasets/coinrun/"`.
7
+ 2. Download the MUGEN dataset assets [here](https://github.com/mugen-org/MUGEN_baseline/tree/main/lib/data/coinrun/assets) and save them as `datasets/coinrun/assets` in your working directory.
8
+ * Downloading the assets from GitHub requires `git clone`-ing the original MUGEN repo and copying the assets directory located at `MUGEN_baseline/lib/data/coinrun/assets`.
9
+
10
+ Note: saving the dataset and assets to locations other than those listed above requires passing custom arguments to `MUGENDataModuleBase` or `MUGENDataset` through `MUGENDatasetArgs.data_path` and `MUGENDatasetArgs.asset_path`, respectively.
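For illustration, a minimal sketch of pointing the loader at a custom location. The import paths and any `MUGENDatasetArgs` fields other than `data_path`/`asset_path` are assumptions based on the datamodule code elsewhere in this commit:

```python
# Sketch only: import paths and field defaults are assumptions.
from examples.mugen.data.mugen_dataset import MUGENDataset, MUGENDatasetArgs

args = MUGENDatasetArgs(
    data_path="/my/storage/coinrun",          # custom dataset location
    asset_path="/my/storage/coinrun/assets",  # custom assets location
)
train_set = MUGENDataset(args=args, split="train")
sample = train_set[0]  # dict with "video"/"text"/"audio" keys, depending on the get_* flags
```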
multimodal/examples/mugen/data/coinrun/construct_from_json.py ADDED
@@ -0,0 +1,756 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import os
9
+
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ DEATH_ANIM_LENGTH = 30
14
+ FINISHED_LEVEL_ANIM_LENGTH = 20
15
+ MONSTER_DEATH_ANIM_LENGTH = 3
16
+ SPACE = "."
17
+ LADDER = "="
18
+ LAVA_SURFACE = "^"
19
+ LAVA_MIDDLE = "|"
20
+ WALL_SURFACE = "S"
21
+ WALL_MIDDLE = "A"
22
+ WALL_CLIFF_LEFT = "a"
23
+ WALL_CLIFF_RIGHT = "b"
24
+ COIN_OBJ1 = "1"
25
+ COIN_OBJ2 = "2"
26
+ CRATE_NORMAL = "#"
27
+ CRATE_DOUBLE = "$"
28
+ CRATE_SINGLE = "&"
29
+ CRATE_WARNING = "%"
30
+
31
+
32
+ def define_semantic_color_map(max_label=18):
33
+ assert max_label in [18, 21, 22], f"max_label {max_label} is not supported!"
34
+
35
+ semantic_color_map = {}
36
+
37
+ semantic_color_map["background"] = 0
38
+
39
+ # alien is always set to max_label (assumes it always appears in a video)
40
+ semantic_color_map["alien"] = max_label
41
+
42
+ if max_label == 18:
43
+ semantic_color_map["world"] = {
44
+ WALL_MIDDLE: 3,
45
+ WALL_SURFACE: 4,
46
+ WALL_CLIFF_LEFT: 5,
47
+ WALL_CLIFF_RIGHT: 6,
48
+ COIN_OBJ1: 17,
49
+ COIN_OBJ2: 0,
50
+ CRATE_NORMAL: 8,
51
+ CRATE_DOUBLE: 8,
52
+ CRATE_SINGLE: 8,
53
+ CRATE_WARNING: 8,
54
+ LAVA_MIDDLE: 1,
55
+ LAVA_SURFACE: 2,
56
+ LADDER: 7,
57
+ }
58
+
59
+ semantic_color_map["shield"] = 0
60
+
61
+ semantic_color_map["monster"] = {
62
+ "sawHalf": 16,
63
+ "bee": 15,
64
+ "slimeBlock": 14,
65
+ "slimeBlue": 13,
66
+ "mouse": 12,
67
+ "snail": 11,
68
+ "ladybug": 10,
69
+ "wormPink": 9,
70
+ "barnacle": 0,
71
+ "frog": 0,
72
+ }
73
+ else:
74
+ semantic_color_map["world"] = {
75
+ WALL_MIDDLE: 3,
76
+ WALL_SURFACE: 4,
77
+ WALL_CLIFF_LEFT: 5,
78
+ WALL_CLIFF_RIGHT: 6,
79
+ COIN_OBJ1: 19,
80
+ COIN_OBJ2: 20,
81
+ CRATE_NORMAL: 8,
82
+ CRATE_DOUBLE: 8,
83
+ CRATE_SINGLE: 8,
84
+ CRATE_WARNING: 8,
85
+ LAVA_MIDDLE: 1,
86
+ LAVA_SURFACE: 2,
87
+ LADDER: 7,
88
+ }
89
+
90
+ semantic_color_map["shield"] = 21
91
+
92
+ semantic_color_map["monster"] = {
93
+ "sawHalf": 16,
94
+ "bee": 15,
95
+ "slimeBlock": 14,
96
+ "slimeBlue": 13,
97
+ "mouse": 12,
98
+ "snail": 11,
99
+ "ladybug": 10,
100
+ "wormPink": 9,
101
+ "barnacle": 17,
102
+ "frog": 18,
103
+ }
104
+
105
+ return semantic_color_map
106
+
107
+
108
+ def generate_asset_paths(game):
109
+ # use background corresponding with ground theme
110
+ bgtheme = game.background_themes[game.world_theme_n]
111
+
112
+ gtheme = game.ground_themes[game.world_theme_n]
113
+ walls = "kenney/Ground/" + gtheme + "/" + gtheme.lower()
114
+
115
+ # default option with fixed agent look
116
+ atheme = game.agent_themes[game.agent_theme_n]
117
+ alien = "kenneyLarge/Players/128x256_no_helmet/" + atheme + "/alien" + atheme
118
+ alien_paths = {"Mugen": alien}
119
+
120
+ tiles = "kenney/Tiles/"
121
+ items = "kenneyLarge/Items/"
122
+ enemy = "kenneyLarge/Enemies/"
123
+
124
+ asset_files = {}
125
+
126
+ asset_files["background"] = bgtheme
127
+
128
+ asset_files["world"] = {
129
+ WALL_MIDDLE: walls + "Center.png",
130
+ WALL_SURFACE: walls + "Mid.png",
131
+ WALL_CLIFF_LEFT: walls + "Cliff_left.png",
132
+ WALL_CLIFF_RIGHT: walls + "Cliff_right.png",
133
+ COIN_OBJ1: items + "coinGold.png",
134
+ COIN_OBJ2: items + "gemRed.png",
135
+ CRATE_NORMAL: tiles + "boxCrate.png",
136
+ CRATE_DOUBLE: tiles + "boxCrate_double.png",
137
+ CRATE_SINGLE: tiles + "boxCrate_single.png",
138
+ CRATE_WARNING: tiles + "boxCrate_warning.png",
139
+ LAVA_MIDDLE: tiles + "lava.png",
140
+ LAVA_SURFACE: tiles + "lavaTop_low.png",
141
+ LADDER: tiles + "ladderMid.png",
142
+ }
143
+
144
+ asset_files["alien"] = {}
145
+ for alien_name in alien_paths.keys():
146
+ asset_files["alien"][alien_name] = {
147
+ "walk1": alien_paths[alien_name] + "_walk1.png",
148
+ "walk2": alien_paths[alien_name] + "_walk2.png",
149
+ "climb1": alien_paths[alien_name] + "_climb1.png",
150
+ "climb2": alien_paths[alien_name] + "_climb2.png",
151
+ "stand": alien_paths[alien_name] + "_stand.png",
152
+ "jump": alien_paths[alien_name] + "_jump.png",
153
+ "duck": alien_paths[alien_name] + "_duck.png",
154
+ "hit": alien_paths[alien_name] + "_hit.png",
155
+ }
156
+ asset_files["shield"] = "bubble_shield.png"
157
+
158
+ game.flatten_monster_names()
159
+ # monster assets are generated based on list of names used at rendering
160
+ asset_files["monster"] = {
161
+ name: enemy + name + ".png" for name in game.flattened_monster_names
162
+ }
163
+
164
+ return asset_files
165
+
166
+
167
+ # binarize alpha channel if input img is in RGBA mode, set anything above 0 to 255
168
+ def binarize_alpha_channel(img):
169
+ if img.mode != "RGBA":
170
+ return img
171
+
172
+ w, h = img.size
173
+ for i in range(w):
174
+ for j in range(h):
175
+ pixel = img.getpixel((i, j))
176
+
177
+ # set alpha to 255 if alpha > 0
178
+ if pixel[3] > 0:
179
+ img.putpixel((i, j), (pixel[0], pixel[1], pixel[2], 255))
180
+
181
+ return img
182
+
183
+
184
+ class Asset:
185
+ def __init__(
186
+ self,
187
+ name,
188
+ file,
189
+ asset_root,
190
+ kind="world",
191
+ kx=80,
192
+ ky=80,
193
+ semantic_color=(0, 0, 0),
194
+ flip=False,
195
+ binarize_alpha=False,
196
+ ):
197
+ self.name = name
198
+ self.file = file
199
+ self.asset_root = asset_root
200
+ self.kind = kind
201
+ self.kx = kx
202
+ self.ky = ky
203
+ self.semantic_color = semantic_color
204
+ self.flip = flip
205
+ self.binarize_alpha = binarize_alpha
206
+
207
+ self.load_asset()
208
+
209
+ def load_asset(self):
210
+ asset_path = os.path.join(self.asset_root, self.file)
211
+ if not os.path.isfile(asset_path):
212
+ # basically remove the '_walk1' postfix
213
+ fallback_path = (
214
+ "_".join(asset_path.split("_")[:-1]) + "." + asset_path.split(".")[-1]
215
+ )
216
+ assert os.path.isfile(fallback_path), asset_path
217
+ asset_path = fallback_path
218
+ self.asset = Image.open(asset_path)
219
+
220
+ # used for (user-controlled) asset swap, because the alien's h:w is 2:1 while other assets are 1:1;
221
+ # the asset resize at load time and the render grid size both need to change accordingly
222
+ self.aspect_ratio = self.asset.size[1] / self.asset.size[0]
223
+
224
+ if self.kind == "world":
225
+ if self.name != LAVA_MIDDLE and self.name != LAVA_SURFACE:
226
+ # LAVA has a special way of rendering animation so don't resize now
227
+ self.asset = self.asset.resize(
228
+ (math.ceil(self.kx + 0.5), math.ceil(self.ky + 0.5))
229
+ )
230
+ elif self.kind == "alien":
231
+ self.asset = self.asset.resize(
232
+ (math.ceil(self.kx), math.ceil(self.aspect_ratio * self.ky))
233
+ )
234
+ elif self.kind == "shield":
235
+ self.asset = self.asset.resize(
236
+ (math.ceil(self.kx * 1.15), math.ceil(self.ky * 2.1))
237
+ )
238
+ elif self.kind == "monster" or self.kind == "background":
239
+ self.asset = self.asset.resize((math.ceil(self.kx), math.ceil(self.ky)))
240
+ else:
241
+ raise NotImplementedError(f"Unknown asset kind {self.kind}")
242
+
243
+ # flip if needed (for facing left/right)
244
+ if self.flip:
245
+ self.asset = self.asset.transpose(Image.FLIP_LEFT_RIGHT)
246
+
247
+ if self.binarize_alpha:
248
+ self.asset = binarize_alpha_channel(self.asset)
249
+
250
+
251
+ def load_assets(
252
+ asset_files, asset_root, semantic_color_map, kx=80, ky=80, gen_original=False
253
+ ):
254
+ asset_map = {}
255
+
256
+ for kind in asset_files.keys():
257
+ assert kind in semantic_color_map
258
+
259
+ if kind == "background":
260
+ # background will be loaded separately
261
+ continue
262
+
263
+ if kind == "shield":
264
+ # asset file for the bubble shield in agent power-up mode
265
+ asset_map[kind] = Asset(
266
+ name=kind,
267
+ file=asset_files[kind],
268
+ asset_root=asset_root,
269
+ kind=kind,
270
+ kx=kx,
271
+ ky=ky,
272
+ semantic_color=semantic_color_map[kind],
273
+ binarize_alpha=not gen_original,
274
+ )
275
+ continue
276
+
277
+ for key in asset_files[kind].keys():
278
+ if kind == "world":
279
+ # ground asset, no need to worry about pose or facing
280
+ asset_map[key] = Asset(
281
+ name=key,
282
+ file=asset_files[kind][key],
283
+ asset_root=asset_root,
284
+ kind=kind,
285
+ kx=kx,
286
+ ky=ky,
287
+ semantic_color=semantic_color_map[kind][key],
288
+ binarize_alpha=not gen_original,
289
+ )
290
+ elif kind == "alien":
291
+ for pose in asset_files[kind][key].keys():
292
+ # facing right is default to empty
293
+ all_facings = ["", "_left"]
294
+ for facing in all_facings:
295
+ a_key = key + "_" + pose + facing
296
+
297
+ asset_map[a_key] = Asset(
298
+ name=a_key,
299
+ file=asset_files[kind][key][pose],
300
+ asset_root=asset_root,
301
+ kind=kind,
302
+ kx=kx,
303
+ ky=ky,
304
+ semantic_color=semantic_color_map[kind],
305
+ flip=(facing != ""), # flip the asset if facing is not ''
306
+ binarize_alpha=not gen_original,
307
+ )
308
+ elif kind == "monster":
309
+ # for monsters, 3 types of assets will be loaded
310
+ # for each of them, facing can be left or right
311
+ all_poses = ["", "_move", "_dead"]  # walk1 defaults to the empty string
312
+ all_facings = ["", "_right"]  # facing left defaults to the empty string
313
+ base_fn = os.path.splitext(asset_files[kind][key])[
314
+ 0
315
+ ] # e.g. Enemies/bee
316
+ for pose in all_poses:
317
+ for facing in all_facings:
318
+ m_key = key + pose + facing
319
+ file_name = base_fn + pose + ".png"
320
+
321
+ asset_map[m_key] = Asset(
322
+ name=m_key,
323
+ file=file_name,
324
+ asset_root=asset_root,
325
+ kind="monster",
326
+ kx=kx,
327
+ ky=ky,
328
+ semantic_color=semantic_color_map[kind][key],
329
+ flip=(facing != ""), # flip the asset if facing is not ''
330
+ binarize_alpha=not gen_original,
331
+ )
332
+ else:
333
+ raise NotImplementedError(f"Unknown asset kind {kind}")
334
+
335
+ return asset_map
336
+
337
+
338
+ # load background asset, zoom is different so need a separate function
339
+ def load_bg_asset(asset_files, asset_root, semantic_color_map, zx, zy):
340
+ kind = "background"
341
+ bg_asset = Asset(
342
+ name=kind,
343
+ file=asset_files[kind],
344
+ asset_root=asset_root,
345
+ kind=kind,
346
+ kx=zx,
347
+ ky=zy,
348
+ semantic_color=semantic_color_map[kind],
349
+ )
350
+ return bg_asset
351
+
352
+
353
+ # used for alien dying animation in gen_original mode
354
+ def get_transparent_asset(input_asset, transparency):
355
+ assert input_asset.mode == "RGBA"
356
+ np_asset = np.array(input_asset, dtype=np.int16)
357
+ np_asset[:, :, 3] -= transparency
358
+ np_asset[:, :, 3] = np.clip(np_asset[:, :, 3], 0, None)
359
+ return Image.fromarray(np_asset.astype(np.uint8))
360
+
361
+
362
+ # return rect in integer values, floor for x1,y1, ceil for x2,y2 or w,h
363
+ def integer_rect(rect):
364
+ return [
365
+ math.floor(rect[0]),
366
+ math.floor(rect[1]),
367
+ math.ceil(rect[2]),
368
+ math.ceil(rect[3]),
369
+ ]
370
+
371
+
372
+ def convert_xywh_to_xyxy(rect):
373
+ return [rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3]]
374
+
375
+
376
+ def convert_xyxy_to_xywh(rect):
377
+ return [rect[0], rect[1], rect[2] - rect[0], rect[3] - rect[1]]
378
+
379
+
380
+ # rect format is xywh, img_size is (w,h)
381
+ def check_out_of_bounds(rect, img_size):
382
+ if rect[0] + rect[2] < 0:
383
+ return True
384
+ if rect[0] > img_size[0]:
385
+ return True
386
+ if rect[1] + rect[3] < 0:
387
+ return True
388
+ if rect[1] > img_size[1]:
389
+ return True
390
+ return False
391
+
392
+
393
+ # return intersect of two rects, input and output are both in xywh format
394
+ def intersect_rects(rect1, rect2):
395
+ xyxy_rect1 = convert_xywh_to_xyxy(rect1)
396
+ xyxy_rect2 = convert_xywh_to_xyxy(rect2)
397
+ xyxy_res_rect = [
398
+ max(xyxy_rect1[0], xyxy_rect2[0]),
399
+ max(xyxy_rect1[1], xyxy_rect2[1]),
400
+ min(xyxy_rect1[2], xyxy_rect2[2]),
401
+ min(xyxy_rect1[3], xyxy_rect2[3]),
402
+ ]
403
+
404
+ xywh_res_rect = convert_xyxy_to_xywh(xyxy_res_rect)
405
+
406
+ # check if the intersection is empty
407
+ if xywh_res_rect[2] > 0 and xywh_res_rect[3] > 0:
408
+ return xywh_res_rect
409
+ else:
410
+ return None
411
+
412
+
413
+ # rect is in the format of xywh
414
+ def paint_color_in_rect_with_mask(
415
+ img, rect, color, mask, gen_original=False, ignore_mask=False, cut_mask_top_ratio=0
416
+ ):
417
+ w, h = mask.size
418
+ img_w, img_h = img.size
419
+ # in some cases, mask size doesn't match the rect (e.g. monster dying)
420
+ if rect[2] != w or rect[3] != h:
421
+ if not gen_original:
422
+ mask = mask.resize((rect[2], rect[3]), resample=Image.NEAREST)
423
+ else:
424
+ mask = mask.resize((rect[2], rect[3]))
425
+ w, h = mask.size
426
+
427
+ if not gen_original:
428
+ # generate semantic map
429
+ if ignore_mask and cut_mask_top_ratio != 0:
430
+ # specifically for agent because its asset has a large empty area in the top,
431
+ # we don't want it to be fully masked
432
+ if cut_mask_top_ratio < 0:
433
+ # automatically calculate the first non-empty row from the top
434
+ np_mask = np.array(mask)
435
+ cut_mask_top_rows = (np_mask.T[0].sum(axis=0) != 0).argmax(axis=0)
436
+ else:
437
+ cut_mask_top_rows = int(cut_mask_top_ratio * rect[2])
438
+ rect[1] += cut_mask_top_rows
439
+ rect[3] = mask.size[1] - cut_mask_top_rows
440
+
441
+ img.paste(color, convert_xywh_to_xyxy(rect))
442
+ else:
443
+ # paste in single color if generating semantic maps (so not original)
444
+ # if ignore_mask, this will generate a complete block mask same as rect
445
+ img.paste(
446
+ color,
447
+ convert_xywh_to_xyxy(rect),
448
+ mask if (mask.mode == "RGBA" and not ignore_mask) else None,
449
+ )
450
+ else:
451
+ # generate rgb data
452
+ img.paste(
453
+ mask, convert_xywh_to_xyxy(rect), mask if mask.mode == "RGBA" else None
454
+ )
455
+
456
+ return
457
+
458
+
459
+ def draw_game_frame(
460
+ game,
461
+ frame_id,
462
+ asset_map,
463
+ kx,
464
+ ky,
465
+ gen_original=False,
466
+ bbox_smap_for_agent=False,
467
+ bbox_smap_for_monsters=False,
468
+ alien_name=None,
469
+ skip_foreground=False,
470
+ skip_background=False,
471
+ skip_mugen=False,
472
+ only_mugen=False,
473
+ ):
474
+ # set default alien name/key
475
+ if alien_name is None:
476
+ alien_name = "Mugen"
477
+
478
+ # initialize an empty image (all zero, for background)
479
+ if not gen_original:
480
+ img = Image.new("L", (game.video_res, game.video_res))
481
+ else:
482
+ img = Image.new("RGB", (game.video_res, game.video_res))
483
+
484
+ video_center = (game.video_res - 1) // 2
485
+
486
+ frame = game.frames[frame_id]
487
+
488
+ # for agent-centric
489
+ # dx = -frame.agent.x * kx + video_center - 0.5 * kx
490
+ # dy = frame.agent.y * ky - video_center - 0.5 * ky
491
+ # for video data (no vertical camera move)
492
+ dx = -frame.agent.x * kx + video_center - 0.5 * kx
493
+
494
+ # different dy/ky ratio based on zoom level, to adjust camera view
495
+ if game.zoom == 5.5:
496
+ dy_ratio = 5.0
497
+ elif game.zoom == 4.3:
498
+ dy_ratio = 6.5
499
+ elif game.zoom == 5.0:
500
+ dy_ratio = 5.5
501
+ elif game.zoom == 6.0:
502
+ dy_ratio = 4.5
503
+ else:
504
+ raise NotImplementedError(f"zoom level {game.zoom} is not supported!")
505
+ dy = -video_center + dy_ratio * ky
506
+
507
+ # update background image with proper zoom for gen_original mode
508
+ # NOTE: if desired background label is not zero, set it here to asset_map['background'].semantic_color
509
+ if gen_original and not skip_background and not only_mugen:
510
+ zx = game.video_res * game.zoom
511
+ zy = zx
512
+ for tile_x in range(-1, 3):
513
+ for tile_y in range(-1, 2):
514
+ bg_rect = [0, 0, zx, zy]
515
+ bg_rect[0] = (
516
+ zx * tile_x
517
+ + video_center
518
+ + game.bgzoom * (dx + kx * game.maze_h / 2)
519
+ - zx * 0.5
520
+ )
521
+ bg_rect[1] = (
522
+ zy * tile_y
523
+ + video_center
524
+ + game.bgzoom * (dy - ky * game.maze_h / 2)
525
+ - zy * 0.5
526
+ )
527
+ if check_out_of_bounds(bg_rect, img.size):
528
+ continue
529
+ img.paste(
530
+ asset_map["background"].asset,
531
+ convert_xywh_to_xyxy(integer_rect(bg_rect)),
532
+ )
533
+
534
+ # NOTE: the game engine now hard-codes 64 for maze_size
535
+ radius = int(1 + game.maze_w / game.zoom)
536
+ ix = int(frame.agent.x + 0.5)
537
+ iy = int(frame.agent.y + 0.5)
538
+ x_start = max(ix - radius, 0)
539
+ x_end = min(ix + radius + 1, game.maze_w)
540
+ y_start = max(iy - radius, 0)
541
+ y_end = min(iy + radius + 1, game.maze_h)
542
+ win_h = game.video_res
543
+
544
+ # convert eaten coins to a set for faster checking coordinates
545
+ coins_eaten_set = {tuple(coin_coord) for coin_coord in frame.coins_eaten}
546
+
547
+ if not skip_background and not only_mugen:
548
+ for y in range(y_start, y_end):
549
+ for x in range(x_start, x_end):
550
+ wkey = game.maze[y][x]
551
+ if wkey == SPACE:
552
+ continue
553
+
554
+ # eaten coins is treated the same as SPACE, just continue
555
+ # but we should not modify the coins in maze to SPACE, or it may cause inconsistency
556
+ # if we ever need to render backwards or save json after drawing
557
+ if (x, y) in coins_eaten_set:
558
+ continue
559
+
560
+ assert wkey in asset_map, f"{wkey} not in assets!"
561
+
562
+ tile_rect = [
563
+ kx * x + dx - 0.1,
564
+ win_h - ky * y + dy - 0.1,
565
+ kx + 0.5 + 0.2,
566
+ ky + 0.5 + 0.2,
567
+ ]
568
+
569
+ # skip tile if the rect is completely out-of-bounds
570
+ if check_out_of_bounds(tile_rect, img.size):
571
+ continue
572
+
573
+ if wkey == LAVA_MIDDLE or wkey == LAVA_SURFACE:
574
+ d1 = tile_rect[:]
575
+ d2 = tile_rect[:]
576
+ asset_size = asset_map[wkey].asset.size
577
+ sr = [0, 0, asset_size[0], asset_size[1]]
578
+ sr1 = sr[:]
579
+ sr2 = sr[:]
580
+ tr = frame.state_time * 0.1
581
+ tr -= int(tr)
582
+ tr *= -1
583
+ d1[0] += tr * tile_rect[2]
584
+ d2[0] += tile_rect[2] + tr * tile_rect[2]
585
+ sr1[0] += -tr * asset_size[0]
586
+ sr2[0] += -asset_size[0] - tr * asset_size[0]
587
+ d1 = intersect_rects(d1, tile_rect)
588
+ d2 = intersect_rects(d2, tile_rect)
589
+ if d1 is not None:
590
+ d1[2] += 0.5
591
+ if d2 is not None:
592
+ d2[0] -= 0.5
593
+ d2[2] += 0.5
594
+ sr1 = intersect_rects(sr1, sr)
595
+ sr2 = intersect_rects(sr2, sr)
596
+ if sr1 is not None and d1 is not None:
597
+ # crop and render one half of the asset
598
+ crop_mask = asset_map[wkey].asset.crop(
599
+ integer_rect(convert_xywh_to_xyxy(sr1))
600
+ )
601
+ paint_color_in_rect_with_mask(
602
+ img,
603
+ integer_rect(d1),
604
+ asset_map[wkey].semantic_color,
605
+ crop_mask,
606
+ gen_original=gen_original,
607
+ )
608
+ if sr2 is not None and d2 is not None:
609
+ # crop and render the other half of the asset (swapped places horizontally)
610
+ crop_mask = asset_map[wkey].asset.crop(
611
+ integer_rect(convert_xywh_to_xyxy(sr2))
612
+ )
613
+ paint_color_in_rect_with_mask(
614
+ img,
615
+ integer_rect(d2),
616
+ asset_map[wkey].semantic_color,
617
+ crop_mask,
618
+ gen_original=gen_original,
619
+ )
620
+ else:
621
+ paint_color_in_rect_with_mask(
622
+ img,
623
+ integer_rect(tile_rect),
624
+ asset_map[wkey].semantic_color,
625
+ asset_map[wkey].asset,
626
+ gen_original=gen_original,
627
+ )
628
+
629
+ if not skip_foreground:
630
+ if not only_mugen:
631
+ # paint monsters
632
+ for mi in range(len(frame.monsters)):
633
+ if frame.monsters[mi].is_dead:
634
+ dying_frame_cnt = max(0, frame.monsters[mi].monster_dying_frame_cnt)
635
+ monster_shrinkage = (
636
+ (MONSTER_DEATH_ANIM_LENGTH - dying_frame_cnt)
637
+ * 0.8
638
+ / MONSTER_DEATH_ANIM_LENGTH
639
+ )
640
+ monster_rect = [
641
+ math.floor(kx * frame.monsters[mi].x + dx),
642
+ math.floor(
643
+ win_h
644
+ - ky * frame.monsters[mi].y
645
+ + dy
646
+ + ky * monster_shrinkage
647
+ ),
648
+ math.ceil(kx),
649
+ math.ceil(ky * (1 - monster_shrinkage)),
650
+ ]
651
+ else:
652
+ monster_rect = [
653
+ math.floor(kx * frame.monsters[mi].x + dx),
654
+ math.floor(win_h - ky * frame.monsters[mi].y + dy),
655
+ math.ceil(kx),
656
+ math.ceil(ky),
657
+ ]
658
+
659
+ m_name = game.flattened_monster_names[frame.monsters[mi].theme]
660
+ # add pose and facing to the key to find correct asset
661
+ m_pose = "" if frame.monsters[mi].walk1_mode else "_move"
662
+ if frame.monsters[mi].is_dead:
663
+ m_pose = "_dead"
664
+ m_key = (
665
+ m_name + m_pose + ("_right" if frame.monsters[mi].vx > 0 else "")
666
+ )
667
+
668
+ paint_color_in_rect_with_mask(
669
+ img,
670
+ monster_rect,
671
+ asset_map[m_key].semantic_color,
672
+ asset_map[m_key].asset,
673
+ gen_original=gen_original,
674
+ ignore_mask=bbox_smap_for_monsters,
675
+ )
676
+
677
+ if not skip_mugen:
678
+ # paint agent - do it after monsters so agent is always in front
679
+ a_key = (
680
+ alien_name
681
+ + "_"
682
+ + frame.agent.pose
683
+ + ("" if frame.agent.is_facing_right else "_left")
684
+ )
685
+ # note how aspect_ratio is used for the alien rect; the same can be applied to
686
+ # the monster rect to support assets that are not 1:1 (e.g. using the alien as a monster)
687
+ alien_rect = [
688
+ math.floor(kx * frame.agent.x + dx),
689
+ # math.floor(win_h - ky * (frame.agent.y + 1) + dy), # default for 2:1 alien, no asset swap
690
+ math.floor(
691
+ win_h
692
+ - ky * (frame.agent.y + asset_map[a_key].aspect_ratio - 1)
693
+ + dy
694
+ ),
695
+ math.ceil(kx),
696
+ # math.ceil(2 * ky), # default for 2:1 alien, no asset swap
697
+ math.ceil(asset_map[a_key].aspect_ratio * ky),
698
+ ]
699
+ if frame.agent.is_killed:
700
+ transparency = (
701
+ DEATH_ANIM_LENGTH + 1 - frame.agent.killed_animation_frame_cnt
702
+ ) * 12
703
+ # only render if not fully transparent
704
+ if transparency > 255:
705
+ agent_asset = None
706
+ else:
707
+ if gen_original:
708
+ agent_asset = get_transparent_asset(
709
+ asset_map[a_key].asset, transparency
710
+ )
711
+ else:
712
+ # when generating semantic map, alien mask won't change unless fully transparent
713
+ agent_asset = asset_map[a_key].asset
714
+ else:
715
+ agent_asset = asset_map[a_key].asset
716
+ if agent_asset is not None:
717
+ paint_color_in_rect_with_mask(
718
+ img,
719
+ alien_rect,
720
+ asset_map[a_key].semantic_color,
721
+ agent_asset,
722
+ gen_original=gen_original,
723
+ ignore_mask=bbox_smap_for_agent,
724
+ cut_mask_top_ratio=0.8,
725
+ )
726
+
727
+ # paint the bubble shield if agent is in power-up mode
728
+ if frame.agent.power_up_mode:
729
+ shield_rect = [
730
+ # NOTE: the game engine hard-codes 7 and 8 for coordinates, which won't work with a video_res other than 1024
731
+ # (for training we usually generate with 256 or 128 video_res), so need to convert them
732
+ math.floor(kx * frame.agent.x + dx - 7 * game.video_res / 1024),
733
+ math.floor(
734
+ win_h
735
+ - ky * (frame.agent.y + 1)
736
+ + dy
737
+ + 8 * game.video_res / 1024
738
+ ),
739
+ math.ceil(kx * 1.15),
740
+ math.ceil(ky * 2.1),
741
+ ]
742
+ # pull bubble down when Mugen crouches
743
+ if frame.agent.pose == "duck":
744
+ shield_rect[1] += math.floor(8 * game.video_res / 1024)
745
+
746
+ paint_color_in_rect_with_mask(
747
+ img,
748
+ shield_rect,
749
+ asset_map["shield"].semantic_color,
750
+ asset_map["shield"].asset,
751
+ gen_original=gen_original,
752
+ ignore_mask=bbox_smap_for_agent,
753
+ cut_mask_top_ratio=0.45,
754
+ )
755
+
756
+ return img
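To show how the helpers in this file fit together, here is a hedged rendering sketch. The json/asset paths are placeholders, `kx`/`ky` simply reuse the 80-pixel default from `load_assets`, and the import paths are assumptions:

```python
# Illustrative sketch only; paths, kx/ky, and import paths are assumptions.
from examples.mugen.data.coinrun.game import Game
from examples.mugen.data.coinrun.construct_from_json import (
    define_semantic_color_map,
    generate_asset_paths,
    load_assets,
    load_bg_asset,
    draw_game_frame,
)

game = Game()
game.load_json("/path/to/one_game.json")  # placeholder: a single-game json from the dataset

kx = ky = 80  # on-screen tile size in pixels; 80 is the default used by load_assets
semantic_color_map = define_semantic_color_map(max_label=21)
asset_files = generate_asset_paths(game)
asset_map = load_assets(
    asset_files, "datasets/coinrun/assets", semantic_color_map, kx, ky, gen_original=True
)
# the background uses the zoomed size, matching what draw_game_frame expects
zx = zy = game.video_res * game.zoom
asset_map["background"] = load_bg_asset(
    asset_files, "datasets/coinrun/assets", semantic_color_map, zx, zy
)

rgb_frame = draw_game_frame(game, frame_id=0, asset_map=asset_map, kx=kx, ky=ky, gen_original=True)
rgb_frame.save("frame_0000.png")
```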
multimodal/examples/mugen/data/coinrun/game.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+
9
+
10
+ class Game:
11
+ def __init__(self, **kwargs):
12
+ self.game_id = -1
13
+ self.level_seed = 0
14
+ self.rl_agent_seed = 0
15
+ self.zoom = 5.5
16
+ self.bgzoom = 0.4 # NOTE: hard-coded
17
+ self.world_theme_n = -1
18
+ self.agent_theme_n = -1
19
+
20
+ self.background_themes = []
21
+ self.ground_themes = []
22
+ self.agent_themes = []
23
+ self.monster_names = {}
24
+ self.flattened_monster_names = []
25
+
26
+ # TODO: save and load these from the game engine
27
+ self.video_res = 1024
28
+ self.maze_w = 64
29
+ self.maze_h = 13 # for zoom 5.5
30
+
31
+ self.reset_game()
32
+
33
+ self.__dict__.update(**kwargs)
34
+ self.frames = [Frame(**f) for f in self.frames]
35
+
36
+ def reset_game(self):
37
+ self.maze = None
38
+ self.frames = []
39
+
40
+ def asdict(self, f_start=-1, f_end=-1):
41
+ if f_end < 0:
42
+ # show all frames by default
43
+ frames_as_dict = [f.asdict() for f in self.frames]
44
+ else:
45
+ frames_as_dict = [f.asdict() for f in self.frames[f_start:f_end]]
46
+ return {
47
+ "game_id": self.game_id,
48
+ "level_seed": self.level_seed,
49
+ "rl_agent_seed": self.rl_agent_seed,
50
+ "zoom": self.zoom,
51
+ "bgzoom": self.bgzoom,
52
+ "world_theme_n": self.world_theme_n,
53
+ "agent_theme_n": self.agent_theme_n,
54
+ "background_themes": self.background_themes,
55
+ "ground_themes": self.ground_themes,
56
+ "agent_themes": self.agent_themes,
57
+ "monster_names": self.monster_names,
58
+ "video_res": self.video_res,
59
+ "maze_w": self.maze_w,
60
+ "maze_h": self.maze_h,
61
+ "maze": self.maze if self.maze is not None else None,
62
+ "frames": frames_as_dict,
63
+ }
64
+
65
+ def __repr__(self):
66
+ return json.dumps(self.asdict())
67
+
68
+ def save_json(self, json_path, f_start=-1, f_end=-1):
69
+ with open(json_path, "w") as f:
70
+ json.dump(self.asdict(f_start, f_end), f, indent=2)
71
+
72
+ def load_json(self, json_path):
73
+ with open(json_path, "r") as f:
74
+ data = json.load(f)
75
+
76
+ self.reset_game()
77
+ self.__dict__.update(**data)
78
+ self.frames = [Frame(**f) for f in self.frames]
79
+
80
+ self.flatten_monster_names()
81
+ self.reset_eaten_coins()
82
+
83
+ def flatten_monster_names(self):
84
+ # the order is important!
85
+ self.flattened_monster_names = list(self.monster_names["ground"])  # copy so repeated calls don't mutate monster_names
86
+ self.flattened_monster_names.extend(self.monster_names["walking"])
87
+ self.flattened_monster_names.extend(self.monster_names["flying"])
88
+
89
+ # NOTE: some coins might be missing due to how 3s clip json is saved
90
+ # reset all eaten coins to put them back
91
+ # this is a temporary fix until we regenerate all jsons
92
+ def reset_eaten_coins(self):
93
+ for coin_loc in self.frames[-1].coins_eaten:
94
+ # note the game rows are saved as strings
95
+ # NOTE: '1' is the yellow coin; we also have another type '2', which is the red gem,
96
+ # but the json with '2' enabled should not have this issue
97
+ if self.maze[coin_loc[1]][coin_loc[0]] == ".":
98
+ self.maze[coin_loc[1]] = (
99
+ self.maze[coin_loc[1]][: coin_loc[0]]
100
+ + "1"
101
+ + self.maze[coin_loc[1]][(coin_loc[0] + 1) :]
102
+ )
103
+
104
+
105
+ class Frame:
106
+ def __init__(self, **kwargs):
107
+ self.frame_id = -1
108
+ self.file_name = ""
109
+ self.state_time = 0
110
+ self.coins_eaten = []
111
+ self.agent = None
112
+ self.monsters = []
113
+
114
+ self.__dict__.update(**kwargs)
115
+ if "agent" in self.__dict__ and self.agent is not None:
116
+ self.agent = Agent(**self.agent)
117
+ if "monsters" in self.__dict__:
118
+ self.monsters = [Monster(**m) for m in self.monsters]
119
+
120
+ def asdict(self):
121
+ return {
122
+ "frame_id": self.frame_id,
123
+ "file_name": self.file_name,
124
+ "state_time": self.state_time,
125
+ "coins_eaten": self.coins_eaten,
126
+ "agent": self.agent.asdict() if self.agent is not None else None,
127
+ "monsters": [m.asdict() for m in self.monsters],
128
+ }
129
+
130
+ def __repr__(self):
131
+ return json.dumps(self.asdict())
132
+
133
+
134
+ class Agent:
135
+ def __init__(
136
+ self,
137
+ x,
138
+ y,
139
+ vx=0.0,
140
+ vy=0.0,
141
+ time_alive=0,
142
+ ladder=False,
143
+ spring=0,
144
+ is_killed=False,
145
+ killed_animation_frame_cnt=0,
146
+ finished_level_frame_cnt=0,
147
+ killed_monster=False,
148
+ bumped_head=False,
149
+ collected_coin=False,
150
+ collected_gem=False,
151
+ power_up_mode=False,
152
+ **kwargs,
153
+ ):
154
+ self.x = x
155
+ self.y = y
156
+ self.vx = vx
157
+ self.vy = vy
158
+ self.time_alive = time_alive
159
+ self.ladder = ladder # for climb pose
160
+ self.spring = spring # for duck pose
161
+
162
+ # states related to agent dying or finishing animations
163
+ self.is_killed = is_killed
164
+ self.killed_animation_frame_cnt = killed_animation_frame_cnt
165
+ self.finished_level_frame_cnt = finished_level_frame_cnt
166
+ self.killed_monster = killed_monster
167
+ self.bumped_head = bumped_head
168
+ self.collected_coin = collected_coin
169
+ self.collected_gem = collected_gem
170
+ self.power_up_mode = power_up_mode
171
+
172
+ self.anim_freq = 5 # hard-coded
173
+
174
+ # decide whether to flip asset horizontally
175
+ self.is_facing_right = True
176
+ if self.vx < 0:
177
+ self.is_facing_right = False
178
+
179
+ # decide which of the two walk/climb assets to use
180
+ self.walk1_mode = True
181
+ if (self.time_alive // self.anim_freq) % 2 != 0:
182
+ self.walk1_mode = False
183
+
184
+ self.pose = self.get_pose()
185
+
186
+ # kwargs are ignored
187
+ # self.__dict__.update(**kwargs)
188
+
189
+ def get_pose(self):
190
+ if self.is_killed:
191
+ return "hit"
192
+ if self.ladder:
193
+ if self.walk1_mode:
194
+ return "climb1"
195
+ else:
196
+ return "climb2"
197
+ if self.vy != 0:
198
+ return "jump"
199
+ if self.spring != 0:
200
+ return "duck"
201
+ if self.vx == 0:
202
+ return "stand"
203
+ if self.walk1_mode:
204
+ return "walk1"
205
+ else:
206
+ return "walk2"
207
+
208
+ def asdict(self):
209
+ return {
210
+ "x": self.x,
211
+ "y": self.y,
212
+ "vx": self.vx,
213
+ "vy": self.vy,
214
+ "time_alive": self.time_alive,
215
+ "ladder": self.ladder,
216
+ "spring": self.spring,
217
+ "is_killed": self.is_killed,
218
+ "killed_animation_frame_cnt": self.killed_animation_frame_cnt,
219
+ "finished_level_frame_cnt": self.finished_level_frame_cnt,
220
+ "killed_monster": self.killed_monster,
221
+ "bumped_head": self.bumped_head,
222
+ "collected_coin": self.collected_coin,
223
+ "collected_gem": self.collected_gem,
224
+ "power_up_mode": self.power_up_mode,
225
+ "anim_freq": self.anim_freq,
226
+ "is_facing_right": self.is_facing_right,
227
+ "walk1_mode": self.walk1_mode,
228
+ "pose": self.pose,
229
+ }
230
+
231
+ def __repr__(self):
232
+ return json.dumps(self.asdict())
233
+
234
+
235
+ class Monster:
236
+ def __init__(
237
+ self,
238
+ m_id,
239
+ x,
240
+ y,
241
+ vx=0.0,
242
+ vy=0.0,
243
+ theme=0,
244
+ is_flying=False,
245
+ is_walking=False,
246
+ is_jumping=False,
247
+ is_dead=False,
248
+ time=0,
249
+ anim_freq=1,
250
+ monster_dying_frame_cnt=0,
251
+ **kwargs,
252
+ ):
253
+ self.m_id = m_id
254
+ self.x = x
255
+ self.y = y
256
+ self.vx = vx
257
+ self.vy = vy
258
+ self.theme = theme # monster type (saw, snail, slime, etc.)
259
+ self.is_flying = is_flying
260
+ self.is_walking = is_walking
261
+ self.is_jumping = is_jumping
262
+ self.is_dead = is_dead
263
+ self.time = time
264
+ self.anim_freq = anim_freq
265
+ self.monster_dying_frame_cnt = monster_dying_frame_cnt
266
+
267
+ # decide which of the two walk/climb assets to use
268
+ self.walk1_mode = True
269
+ if self.is_jumping:
270
+ # for jumping monster, walk1 asset is decided by vertical speed
271
+ if self.vy != 0:
272
+ self.walk1_mode = False
273
+ elif (self.time // self.anim_freq) % 2 != 0:
274
+ self.walk1_mode = False
275
+
276
+ def asdict(self):
277
+ return {
278
+ "m_id": self.m_id,
279
+ "x": self.x,
280
+ "y": self.y,
281
+ "vx": self.vx,
282
+ "vy": self.vy,
283
+ "theme": self.theme,
284
+ "is_flying": self.is_flying,
285
+ "is_walking": self.is_walking,
286
+ "is_jumping": self.is_jumping,
287
+ "is_dead": self.is_dead,
288
+ "time": self.time,
289
+ "anim_freq": self.anim_freq,
290
+ "monster_dying_frame_cnt": self.monster_dying_frame_cnt,
291
+ "walk1_mode": self.walk1_mode,
292
+ }
293
+
294
+ def __repr__(self):
295
+ return json.dumps(self.asdict())
multimodal/examples/mugen/data/coinrun/generate_text_desc.py ADDED
@@ -0,0 +1,435 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+
9
+
10
+ class Sequence:
11
+ def __init__(
12
+ self, start_frame, end_frame, pose_type, start_x, start_y, end_x, end_y
13
+ ):
14
+ self.start_frame = start_frame
15
+ self.end_frame = end_frame
16
+
17
+ # 'ground' includes 'walk', 'duck', 'stand'; other types are 'climb', 'jump', 'hit'
18
+ self.pose_type = pose_type
19
+ self.start_x = start_x
20
+ self.start_y = start_y
21
+ self.end_x = end_x
22
+ self.end_y = end_y
23
+ self.time_jumps = 1 if pose_type == "jump" else 0
24
+ self.end_maze_above = "."
25
+ self.end_maze_below = "."
26
+ self.num_coins_eaten = 0
27
+ self.num_gems_eaten = 0
28
+ self.start_shield = False
29
+ self.end_shield = False
30
+ self.changed_shield = False
31
+ self.killed_monsters = []
32
+ self.jump_over_monsters = []
33
+ self.killed_by = ""
34
+ self.text_desc = ""
35
+
36
+ # Decide granularity of text description (skip sequences shorter than this)
37
+ self.min_len_for_text_desc = 5
38
+
39
+ def asdict(self):
40
+ return {
41
+ "start_frame": self.start_frame,
42
+ "end_frame": self.end_frame,
43
+ "pose_type": self.pose_type,
44
+ "start_xy": (self.start_x, self.start_y),
45
+ "end_xy": (self.end_x, self.end_y),
46
+ "bumped_head": self.is_bumped_head(),
47
+ "same_level_jump": self.is_same_level_jump(),
48
+ "num_coins_eaten": self.num_coins_eaten,
49
+ "num_gems_eaten": self.num_gems_eaten,
50
+ "start_shield": self.start_shield,
51
+ "end_shield": self.end_shield,
52
+ "changed_shield": self.changed_shield,
53
+ "killed_monsters": self.killed_monsters,
54
+ "jump_over_monsters": self.jump_over_monsters,
55
+ "killed_by": self.killed_by,
56
+ "text_desc": self.text_desc,
57
+ }
58
+
59
+ def __repr__(self):
60
+ return json.dumps(self.asdict())
61
+
62
+ # bumped head will show as 'walk' pose and last for 1-2 frames
63
+ def is_bumped_head(self):
64
+ if (
65
+ self.pose_type == "ground"
66
+ and (self.end_frame - self.start_frame <= 1)
67
+ and self.end_maze_below in ".12"
68
+ ): # and self.end_maze_above in 'SAab'
69
+ return True
70
+ return False
71
+
72
+ def is_same_level_jump(self):
73
+ if self.pose_type == "jump" and abs(self.end_y - self.start_y) <= 0.5:
74
+ return True
75
+ return False
76
+
77
+ def merge_sequences(self, sequences):
78
+ self.end_frame = sequences[-1].end_frame
79
+ self.end_x = sequences[-1].end_x
80
+ self.end_y = sequences[-1].end_y
81
+ self.end_maze_above = sequences[-1].end_maze_above
82
+ self.end_maze_below = sequences[-1].end_maze_below
83
+ for seq in sequences:
84
+ if seq.is_bumped_head():
85
+ self.time_jumps -= 1
86
+ self.time_jumps += seq.time_jumps
87
+
88
+ self.num_coins_eaten += seq.num_coins_eaten
89
+ self.num_gems_eaten += seq.num_gems_eaten
90
+ self.killed_monsters.extend(seq.killed_monsters)
91
+ self.jump_over_monsters.extend(seq.jump_over_monsters)
92
+
93
+ def process_metadata(self, game):
94
+ # generate game.flattened_monster_names if not already
95
+ # this is used to get monster names
96
+ if len(game.flattened_monster_names) == 0:
97
+ game.flatten_monster_names()
98
+
99
+ # count number of coins and gems eaten during the sequence
100
+ # start from one frame earlier (if not 0) so we can get change in the first frame
101
+ start_frame_id = max(self.start_frame - 1, 0)
102
+ if len(game.frames[self.end_frame].coins_eaten) > len(
103
+ game.frames[start_frame_id].coins_eaten
104
+ ):
105
+ start_coin_set = {
106
+ (coord[0], coord[1])
107
+ for coord in game.frames[start_frame_id].coins_eaten
108
+ }
109
+ end_coin_set = {
110
+ (coord[0], coord[1])
111
+ for coord in game.frames[self.end_frame].coins_eaten
112
+ }
113
+ new_coins_eaten = end_coin_set - start_coin_set
114
+ for coin_coord in new_coins_eaten:
115
+ if game.maze[coin_coord[1]][coin_coord[0]] == "2":
116
+ self.num_gems_eaten += 1
117
+ else:
118
+ self.num_coins_eaten += 1
119
+
120
+ # check if Mugen changes between shield up and down mode during the sequence
121
+ self.start_shield = game.frames[self.start_frame].agent.power_up_mode
122
+ self.end_shield = game.frames[self.end_frame].agent.power_up_mode
123
+ shield_up_mode = False
124
+ shield_down_mode = False
125
+ for frame_id in range(self.start_frame, self.end_frame + 1):
126
+ if game.frames[frame_id].agent.power_up_mode:
127
+ shield_up_mode = True
128
+ else:
129
+ shield_down_mode = True
130
+ if shield_up_mode and shield_down_mode:
131
+ self.changed_shield = True
132
+
133
+ end_frame_id = min(self.end_frame + 2, len(game.frames))
134
+ for frame_id in range(self.start_frame, end_frame_id):
135
+ frame = game.frames[frame_id]
136
+ dead_monsters = set()
137
+ for i, m in enumerate(frame.monsters):
138
+ if m.is_dead:
139
+ dead_monsters.add(i)
140
+ # if more monsters are dead than in the previous frame, record the newly killed monster's name
141
+ if frame_id > self.start_frame and len(dead_monsters) > len(
142
+ prev_dead_monsters
143
+ ):
144
+ killed_monster_theme = frame.monsters[
145
+ list(dead_monsters - prev_dead_monsters)[0]
146
+ ].theme
147
+ self.killed_monsters.append(
148
+ game.flattened_monster_names[killed_monster_theme]
149
+ )
150
+ prev_dead_monsters = dead_monsters.copy()
151
+
152
+ # figure out which monster killed Mugen
153
+ killed_by_m_id = -1
154
+ if self.pose_type == "hit":
155
+ # check the monster distance in the first frame of hit sequence
156
+ m_min_dist = 1000 # just put some random large dist here
157
+ for m in game.frames[self.start_frame].monsters:
158
+ x_dist = self.start_x - m.x
159
+ y_dist = self.start_y - m.y
160
+ m_dist = x_dist * x_dist + y_dist * y_dist
161
+ if m_dist < m_min_dist:
162
+ killed_by_m_id = m.theme
163
+ m_min_dist = m_dist
164
+ if killed_by_m_id != -1:
165
+ self.killed_by = game.flattened_monster_names[killed_by_m_id]
166
+
167
+ # check for monsters jumped over
168
+ if self.pose_type == "jump":
169
+ # for purpose of checking jumped over monsters,
170
+ # ground y is fixed at the y coordinate of the previous frame
171
+ # note for jump sequence, start_y already recorded the location before jump starts
172
+ ground_y = round(self.start_y)
173
+ jump_over_monsters_set = set()
174
+ for frame_id in range(self.start_frame, self.end_frame + 1):
175
+ frame = game.frames[frame_id]
176
+ # this is the location below the agent at the same y level when jump starts
177
+ ground_loc = (round(frame.agent.x), ground_y)
178
+ for i, m in enumerate(frame.monsters):
179
+ if (round(m.x), round(m.y)) == ground_loc:
180
+ # use set to avoid adding duplicates
181
+ jump_over_monsters_set.add(i)
182
+
183
+ # now convert these into names, but only keep those that are still not killed by the next frame
184
+ for m_i in jump_over_monsters_set:
185
+ if not game.frames[end_frame_id - 1].monsters[m_i].is_dead:
186
+ self.jump_over_monsters.append(
187
+ game.flattened_monster_names[frame.monsters[m_i].theme]
188
+ )
189
+
190
+ def generate_text_desc(self):
191
+ # only generate if sequence is long enough
192
+ if self.end_frame - self.start_frame < self.min_len_for_text_desc:
193
+ self.text_desc = ""
194
+ elif self.pose_type == "hit":
195
+ if self.killed_by != "":
196
+ self.text_desc = f"killed by a {self.killed_by}"
197
+ else:
198
+ self.text_desc = "killed by a monster"
199
+ else:
200
+ y_direct = ""
201
+ if self.end_y - self.start_y > 0.5:
202
+ y_direct = " up"
203
+ elif self.start_y - self.end_y > 0.5:
204
+ y_direct = " down"
205
+ else:
206
+ y_direct = " a bit" if self.pose_type == "ground" else ""
207
+ x_direct = ""
208
+ if self.end_x - self.start_x > 0.5:
209
+ x_direct = " to the right"
210
+ elif self.start_x - self.end_x > 0.5:
211
+ x_direct = " to the left"
212
+ else:
213
+ x_direct = " a bit" if self.pose_type == "ground" else ""
214
+
215
+ if self.pose_type == "climb":
216
+ self.text_desc = f"climbs{y_direct} on a ladder"
217
+ elif self.pose_type == "ground":
218
+ self.text_desc = f"walks{x_direct}" # TODO: add random verbs
219
+ elif self.pose_type == "jump":
220
+ jump_time_desc = ""
221
+ if self.time_jumps >= 2:
222
+ jump_time_desc = " a few times"
223
+
224
+ # only add jump destination if it's not a same level jump
225
+ jump_dest_desc = ""
226
+ if y_direct != "":
227
+ if self.end_maze_below in "SAab":
228
+ if self.end_y < 1.5:
229
+ jump_dest_desc = " to the ground"
230
+ else:
231
+ jump_dest_desc = " to a platform"
232
+ elif self.end_maze_below in "#$&%":
233
+ jump_dest_desc = " to a crate"
234
+ elif self.end_maze_below == "=":
235
+ jump_dest_desc = " to a ladder"
236
+
237
+ # add desc for monsters jumped over
238
+ jumped_over_desc = ""
239
+ if len(self.jump_over_monsters) > 0:
240
+ jumped_over_desc = " over a " + " and a ".join(
241
+ self.jump_over_monsters
242
+ )
243
+
244
+ self.text_desc = f"jumps{y_direct}{jump_time_desc}{x_direct}{jumped_over_desc}{jump_dest_desc}"
245
+
246
+ if self.num_coins_eaten > 0 or self.num_gems_eaten > 0:
247
+ self.text_desc += self.generate_collect_coin_desc()
248
+
249
+ if len(self.killed_monsters) > 0:
250
+ self.text_desc += " and killed a " + " and a ".join(
251
+ self.killed_monsters
252
+ )
253
+
254
+ def generate_collect_coin_desc(self):
255
+ if self.num_coins_eaten == 0 and self.num_gems_eaten == 0:
256
+ return ""
257
+
258
+ coin_descs = []
259
+ # add coin description if collected at least one coin
260
+ if self.num_coins_eaten == 1:
261
+ coin_descs.append(" a coin")
262
+ elif self.num_coins_eaten > 1:
263
+ coin_descs.append(" a few coins")
264
+
265
+ # add gem description if collected at least one gem
266
+ if self.num_gems_eaten == 1:
267
+ coin_descs.append(" a gem")
268
+ elif self.num_gems_eaten > 1:
269
+ coin_descs.append(" a few gems")
270
+
271
+ # connects descriptions for coins and gems with 'and'
272
+ coin_descs = " and".join(coin_descs)
273
+
274
+ # shield change should only be a result of eating gem or coin
275
+ if self.changed_shield:
276
+ coin_descs += self.generate_shield_desc()
277
+
278
+ return f" and collects{coin_descs}"
279
+
280
+ def generate_shield_desc(self):
281
+ if not self.start_shield and self.end_shield:
282
+ return " to turn on the shield"
283
+ elif self.start_shield and not self.end_shield:
284
+ return " to turn off the shield"
285
+ else:
286
+ # start and end in the same shield state but still changed shield during sequence
287
+ if self.start_shield:
288
+ return " to turn shield off then on again"
289
+ else:
290
+ return " to turn shield on then off again"
291
+
292
+
293
+ def process_sequence(game, curr_pose_type, start_i, curr_i, last_seq=False):
294
+ # different type of pose, construct a sequence
295
+ # for 'jump', the start and end locations are based on the frames before the first and after the last frame
296
+ # for others, it's the first and last frame
297
+ if curr_pose_type == "jump":
298
+ pos_start_frame = max(start_i - 1, 0)
299
+ pos_end_frame = curr_i
300
+ else:
301
+ pos_start_frame = start_i
302
+ # curr_i will be one frame after, unless it's the last sequence of video
303
+ # however, for jump sequence, we do want one frame after to know where jump lands
304
+ pos_end_frame = curr_i - 1 if not last_seq else curr_i
305
+
306
+ seq = Sequence(
307
+ start_frame=start_i,
308
+ end_frame=curr_i - 1 if not last_seq else curr_i,
309
+ pose_type=curr_pose_type,
310
+ start_x=game.frames[pos_start_frame].agent.x,
311
+ start_y=game.frames[pos_start_frame].agent.y,
312
+ end_x=game.frames[pos_end_frame].agent.x,
313
+ end_y=game.frames[pos_end_frame].agent.y,
314
+ )
315
+ seq.end_maze_above = game.maze[round(seq.end_y) + 1][round(seq.end_x)]
316
+ seq.end_maze_below = game.maze[round(seq.end_y) - 1][round(seq.end_x)]
317
+ # sometimes jump may end a bit over the edge of cliff, this is to catch and fix that
318
+ if curr_pose_type == "jump" and seq.end_maze_below in ".12":
319
+ neighbor_x = (
320
+ int(seq.end_x) * 2 + 1 - round(seq.end_x)
321
+ ) # get the opposite of round()
322
+ seq.end_maze_below = game.maze[round(seq.end_y) - 1][neighbor_x]
323
+
324
+ return seq
325
+
326
+
327
+ def convert_game_to_text_desc(game, start_idx=0, end_idx=-1, alien_name="Mugen"):
328
+ if alien_name is None:
329
+ alien_name = "Mugen"
330
+
331
+ # if end_idx is not specified, set it to end of the game level
332
+ if end_idx == -1:
333
+ end_idx = len(game.frames)
334
+ start_idx = max(0, start_idx)
335
+ end_idx = min(len(game.frames), end_idx)
336
+
337
+ sequences = []
338
+ for i, f in enumerate(game.frames[start_idx:end_idx]):
339
+ pose = f.agent.pose.strip("12")
340
+ if pose in ["walk", "duck", "stand"]:
341
+ pose_type = "ground"
342
+ else:
343
+ pose_type = pose
344
+ if i == 0:
345
+ # first frame, initialize some status
346
+ start_i = 0
347
+ curr_pose_type = pose_type
348
+ continue
349
+
350
+ if pose_type == curr_pose_type:
351
+ # same type of pose, same sequence
352
+ continue
353
+ else:
354
+ seq = process_sequence(
355
+ game, curr_pose_type, start_idx + start_i, start_idx + i, last_seq=False
356
+ )
357
+ sequences.append(seq)
358
+ start_i = i
359
+ curr_pose_type = pose_type
360
+
361
+ # add the last leftover sequence
362
+ seq = process_sequence(
363
+ game, curr_pose_type, start_idx + start_i, start_idx + i, last_seq=True
364
+ )
365
+ sequences.append(seq)
366
+
367
+ # collapse two jumps into one sequence
368
+ # first pass, merge jumps before and after bumped head, this is to correctly identify jumps at the same level
369
+ seq_i = 0
370
+ reduced_sequences = []
371
+ while seq_i < len(sequences):
372
+ if seq_i == 0 or seq_i == len(sequences) - 1:
373
+ reduced_sequences.append(sequences[seq_i])
374
+ seq_i += 1
375
+ elif (
376
+ sequences[seq_i].is_bumped_head()
377
+ and reduced_sequences[-1].pose_type == "jump"
378
+ and sequences[seq_i + 1].pose_type == "jump"
379
+ ):
380
+ # in case of bumped head, merge the jumps before and after
381
+ reduced_sequences[-1].merge_sequences(sequences[seq_i : seq_i + 2])
382
+ seq_i += 2
383
+ else:
384
+ reduced_sequences.append(sequences[seq_i])
385
+ seq_i += 1
386
+ sequences = reduced_sequences
387
+
388
+ # second pass, collapse two jumps into one sequence if they're both same level jumps
389
+ # jump up and down are not merged (unless it's separated by bumped head that will be merged in first pass)
390
+ result_sequences = []
391
+ seq_i = 0
392
+ max_ground_seq_len_to_merge = 5
393
+ while seq_i < len(sequences):
394
+ # only merge if it's a 'ground' sequence, and before/after are both jumps
395
+ if (
396
+ sequences[seq_i].pose_type != "ground"
397
+ or seq_i == 0
398
+ or seq_i == len(sequences) - 1
399
+ ):
400
+ result_sequences.append(sequences[seq_i])
401
+ seq_i += 1
402
+ elif (
403
+ result_sequences[-1].pose_type != "jump"
404
+ or sequences[seq_i + 1].pose_type != "jump"
405
+ ):
406
+ result_sequences.append(sequences[seq_i])
407
+ seq_i += 1
408
+ elif (
409
+ result_sequences[-1].is_same_level_jump()
410
+ and sequences[seq_i + 1].is_same_level_jump()
411
+ and (
412
+ sequences[seq_i].end_frame - sequences[seq_i].start_frame
413
+ < max_ground_seq_len_to_merge
414
+ )
415
+ ):
416
+ # not bumped head, then only merge if sequence is short enough, and both jumps are the same level
417
+ result_sequences[-1].merge_sequences(sequences[seq_i : seq_i + 2])
418
+ seq_i += 2
419
+ else:
420
+ result_sequences.append(sequences[seq_i])
421
+ seq_i += 1
422
+ sequences = result_sequences
423
+
424
+ # generate text description for each sequence
425
+ text_descriptions = []
426
+ for seq in sequences:
427
+ seq.process_metadata(game)
428
+ seq.generate_text_desc()
429
+ if seq.text_desc != "":
430
+ text_descriptions.append(seq.text_desc)
431
+
432
+ # add Mugen in the beginning, then concat by 'and'
433
+ final_text_desc = alien_name + " " + ", and ".join(text_descriptions)
434
+
435
+ return final_text_desc
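A short hedged sketch of driving the function above; the json path, the frame range, and the printed output are placeholders/assumptions:

```python
# Sketch only; the path and frame range are placeholders.
from examples.mugen.data.coinrun.game import Game
from examples.mugen.data.coinrun.generate_text_desc import convert_game_to_text_desc

game = Game()
game.load_json("/path/to/one_game.json")

# Describe a window of frames from the level (the frame range is an assumption).
text = convert_game_to_text_desc(game, start_idx=0, end_idx=96, alien_name="Mugen")
print(text)  # e.g. "Mugen walks to the right, and jumps up to a platform and collects a coin"
```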
multimodal/examples/mugen/data/mugen_datamodules.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ import pytorch_lightning as pl
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.utils.data as data
13
+
14
+ from .mugen_dataset import MUGENDataset, MUGENDatasetArgs
15
+
16
+
17
+ class MUGENDataModule(pl.LightningDataModule):
18
+ """General lightning data module for MUGEN dataset.
19
+
20
+ Args:
21
+ mugen_dataset_args (MUGENDatasetArgs): arguments for MUGENDataset.
22
+ text_transform (Optional[Callable]): transform for text batches.
23
+ Only used when not ``None`` and when ``mugen_dataset_args.get_text_desc = True``.
24
+ Defaults to ``None``.
25
+ video_transform (Optional[Callable]): transform for video batches.
26
+ Only used when not ``None`` and when ``mugen_dataset_args.get_game_frame = True``.
27
+ Defaults to ``None``.
28
+ audio_transform (Optional[Callable]): transform for audio batches.
29
+ Only used when not ``None`` and when ``mugen_dataset_args.get_audio = True``.
30
+ Defaults to ``None``.
31
+ batch_size (int): number of samples per batch.
32
+ Defaults to ``16``.
33
+ num_workers (int): number of subprocesses for data loading.
34
+ Defaults to ``0``, meaning data is loaded in the main process.
35
+ shuffle (bool): whether to reshuffle data after each epoch.
36
+ Defaults to ``True``.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ mugen_dataset_args: MUGENDatasetArgs,
42
+ text_transform: Optional[Callable] = None,
43
+ video_transform: Optional[Callable] = None,
44
+ audio_transform: Optional[Callable] = None,
45
+ batch_size: int = 16,
46
+ num_workers: int = 0,
47
+ shuffle: bool = True,
48
+ ):
49
+ super().__init__()
50
+ self.mugen_dataset_args = mugen_dataset_args
51
+ self.text_transform = text_transform
52
+ self.video_transform = video_transform
53
+ self.audio_transform = audio_transform
54
+ self.batch_size = batch_size
55
+ self.num_workers = num_workers
56
+ self.shuffle = shuffle
57
+
58
+ @property
59
+ def n_classes(self):
60
+ dataset = self._dataset(True)
61
+ return dataset.n_classes
62
+
63
+ def _custom_collate_fn(self, batch):
64
+ collated_batch = {}
65
+ if self.mugen_dataset_args.get_game_frame:
66
+ video = [elem["video"] for elem in batch]
67
+ video = torch.stack(video)
68
+ video = self.video_transform(video) if self.video_transform else video
69
+ collated_batch["video"] = video
70
+ if self.mugen_dataset_args.get_text_desc:
71
+ text = [elem["text"] for elem in batch]
72
+ # cannot be torch.stack'ed because still in raw text form, not Tensor
73
+ text = self.text_transform(text) if self.text_transform else text
74
+ collated_batch["text"] = text
75
+ if self.mugen_dataset_args.get_audio:
76
+ audio = [elem["audio"] for elem in batch]
77
+ audio = torch.stack(audio)
78
+ audio = self.audio_transform(audio) if self.audio_transform else audio
79
+ collated_batch["audio"] = audio
80
+ return collated_batch
81
+
82
+ def _dataset(self, split):
83
+ dataset = MUGENDataset(args=self.mugen_dataset_args, split=split)
84
+ return dataset
85
+
86
+ def _dataloader(self, split):
87
+ dataset = self._dataset(split)
88
+ if dist.is_initialized():
89
+ sampler = data.distributed.DistributedSampler(
90
+ dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
91
+ )
92
+ else:
93
+ sampler = None
94
+ dataloader = data.DataLoader(
95
+ dataset,
96
+ batch_size=self.batch_size,
97
+ num_workers=self.num_workers,
98
+ pin_memory=True,
99
+ sampler=sampler,
100
+ shuffle=sampler is None and self.shuffle is True,
101
+ collate_fn=self._custom_collate_fn,
102
+ )
103
+ return dataloader
104
+
105
+ def train_dataloader(self):
106
+ return self._dataloader("train")
107
+
108
+ def val_dataloader(self):
109
+ return self._dataloader("val")
110
+
111
+ def test_dataloader(self):
112
+ return self._dataloader("test")
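A hedged usage sketch for the datamodule above; the `MUGENDatasetArgs` fields other than the three `get_*` flags are assumptions, and the transform is just an illustrative callable applied to the batched video tensor in `_custom_collate_fn`:

```python
# Sketch only; import paths, field defaults, and the transform are assumptions.
from examples.mugen.data.mugen_dataset import MUGENDatasetArgs
from examples.mugen.data.mugen_datamodules import MUGENDataModule

args = MUGENDatasetArgs(get_game_frame=True, get_text_desc=True, get_audio=False)

datamodule = MUGENDataModule(
    args,
    video_transform=lambda v: v.float() / 255.0,  # illustrative per-batch transform
    batch_size=8,
    num_workers=4,
    shuffle=True,
)

for batch in datamodule.train_dataloader():
    video, text = batch["video"], batch["text"]  # keys set by _custom_collate_fn
    break
```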
multimodal/examples/mugen/generation/LoadAndComparePretrainedVQVAE.ipynb ADDED
@@ -0,0 +1,383 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "ee3d68e4",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Compare MUGEN's Video VQVAE with TorchMultimodal's\n",
9
+ "\n",
10
+ "This notebook loads the public MUGEN checkpoint for Video VQVAE, remaps the state_dict, and loads it into TorchMultimodal's Video VQVAE to ensure the outputs match. "
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "5af9d001",
16
+ "metadata": {},
17
+ "source": [
18
+ "### Set directories\n",
19
+ "\n",
20
+ "Replace these with your local directories."
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 2,
26
+ "id": "071c8b48",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "checkpoint_dir = '/Users/rafiayub/checkpoints/'\n",
31
+ "repo_dir = '/Users/rafiayub/mugen/'\n",
32
+ "home_dir = '/Users/rafiayub/'"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "id": "a3a0f19f",
38
+ "metadata": {},
39
+ "source": [
40
+ "### Clone MUGEN's repo"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "id": "83812502",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "!git clone https://github.com/mugen-org/MUGEN_baseline.git $repo_dir"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "id": "07757cfa",
56
+ "metadata": {},
57
+ "source": [
58
+ "### Download and unzip checkpoints\n",
59
+ "\n",
60
+ "This will take some time."
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "id": "d41a0c86",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "!wget https://dl.noahmt.com/creativity/data/MUGEN_release/checkpoints.zip -P $checkpoint_dir"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "01d9638a",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "import os\n",
81
+ "\n",
82
+ "# Unzip checkpoints\n",
83
+ "zip_location = os.path.join(checkpoint_dir, 'checkpoints.zip')\n",
84
+ "!unzip $zip_location -d $checkpoint_dir"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "markdown",
89
+ "id": "f06c8938",
90
+ "metadata": {},
91
+ "source": [
92
+ "### Load checkpoint into MUGEN model"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": 3,
98
+ "id": "f3e74b3a",
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "import sys\n",
103
+ "import os\n",
104
+ "sys.path.append(home_dir)\n",
105
+ "\n",
106
+ "import torch\n",
107
+ "from torch import nn\n",
108
+ "import mugen\n",
109
+ "\n",
110
+ "ckpt = torch.load(\n",
111
+ " os.path.join(checkpoint_dir, 'generation/video_vqvae/L32/epoch=54-step=599999.ckpt'), \n",
112
+ " map_location=torch.device('cpu')\n",
113
+ ")"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "id": "3ea6d13e",
119
+ "metadata": {},
120
+ "source": [
121
+ "The arguments are taken from MUGEN's training scripts found at: https://github.com/mugen-org/MUGEN_baseline/blob/main/generation/experiments/vqvae/VideoVQVAE_L32.sh"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 4,
127
+ "id": "f81bea2e",
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "class Namespace:\n",
132
+ " def __init__(self, **kwargs):\n",
133
+ " self.__dict__.update(kwargs)\n",
134
+ "\n",
135
+ "\n",
136
+ "vqvae_args=Namespace(\n",
137
+ " embedding_dim=256,\n",
138
+ " n_codes=2048,\n",
139
+ " n_hiddens=240,\n",
140
+ " n_res_layers=4,\n",
141
+ " lr=0.0003,\n",
142
+ " downsample=(4, 32, 32),\n",
143
+ " kernel_size=3,\n",
144
+ " sequence_length=16,\n",
145
+ " resolution=256,\n",
146
+ ")\n",
147
+ "vv_mugen = mugen.VQVAE(vqvae_args)"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 5,
153
+ "id": "fbdcf1f6",
154
+ "metadata": {},
155
+ "outputs": [
156
+ {
157
+ "data": {
158
+ "text/plain": [
159
+ "<All keys matched successfully>"
160
+ ]
161
+ },
162
+ "execution_count": 5,
163
+ "metadata": {},
164
+ "output_type": "execute_result"
165
+ }
166
+ ],
167
+ "source": [
168
+ "vv_mugen.load_state_dict(ckpt['state_dict'])"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "markdown",
173
+ "id": "a6bfb325",
174
+ "metadata": {},
175
+ "source": [
176
+ "### Create TorchMultimodal's Video VQVAE"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 6,
182
+ "id": "74e6bd54",
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "from examples.mugen.generation.video_vqvae import video_vqvae_mugen\n",
187
+ "\n",
188
+ "vv_torchmm = video_vqvae_mugen(pretrained_model_key=None)"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "markdown",
193
+ "id": "e612d831",
194
+ "metadata": {},
195
+ "source": [
196
+ "### Remap MUGEN's state_dict and load into new model"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": 7,
202
+ "id": "5f4d4774",
203
+ "metadata": {},
204
+ "outputs": [],
205
+ "source": [
206
+ "import re\n",
207
+ "\n",
208
+ "def map_state_dict(state_dict):\n",
209
+ " mapped_state_dict = {}\n",
210
+ " dim_map = {'w': '2', 'h': '1', 't': '0'}\n",
211
+ " layer_map = {'w_qs': 'query', 'w_ks': 'key', 'w_vs': 'value', 'fc': 'output'}\n",
212
+ " for param, val in state_dict.items():\n",
213
+ " new_param = param\n",
214
+ " res = re.search('encoder.convs.', param)\n",
215
+ " if res:\n",
216
+ " idx = res.end()\n",
217
+ " layer_id = int(param[idx])\n",
218
+ " new_param = param[:idx] + str(layer_id * 2) + param[idx+1:]\n",
219
+ " mapped_state_dict[new_param] = val\n",
220
+ " continue\n",
221
+ " res = re.search('encoder.conv_last', param)\n",
222
+ " if res:\n",
223
+ " idx = res.start() + len('encoder.')\n",
224
+ " new_param = param[:idx] + 'convs.10' + param[res.end():]\n",
225
+ " mapped_state_dict[new_param] = val\n",
226
+ " continue\n",
227
+ " res = re.search('attn_[w,h,t]\\..*\\.', param)\n",
228
+ " if res:\n",
229
+ " dim = param[res.start()+5]\n",
230
+ " new_dim = dim_map[dim]\n",
231
+ " layer = param[res.start()+7:res.end()-1]\n",
232
+ " new_layer = layer_map[layer]\n",
233
+ " new_param = param[:res.start()] + 'mha_attns.' + new_dim + '.' + new_layer + '.' + param[res.end():]\n",
234
+ " mapped_state_dict[new_param] = val\n",
235
+ " continue\n",
236
+ " res = re.search('pre_vq_conv', param)\n",
237
+ " if res:\n",
238
+ " new_param = 'encoder.conv_out' + param[res.end():]\n",
239
+ " mapped_state_dict[new_param] = val\n",
240
+ " continue\n",
241
+ " res = re.search('post_vq_conv', param)\n",
242
+ " if res:\n",
243
+ " new_param = 'decoder.conv_in' + param[res.end():]\n",
244
+ " mapped_state_dict[new_param] = val\n",
245
+ " continue\n",
246
+ " res = re.search('decoder.convts.', param)\n",
247
+ " if res:\n",
248
+ " idx = res.end()\n",
249
+ " layer_id = int(param[idx])\n",
250
+ " new_param = param[:idx] + str(layer_id * 2) + param[idx+1:]\n",
251
+ " mapped_state_dict[new_param] = val\n",
252
+ " continue\n",
253
+ " if param == 'codebook.N':\n",
254
+ " new_param = 'codebook.code_usage'\n",
255
+ " mapped_state_dict[new_param] = val\n",
256
+ " continue\n",
257
+ " if param == 'codebook.z_avg':\n",
258
+ " new_param = 'codebook.code_avg'\n",
259
+ " mapped_state_dict[new_param] = val\n",
260
+ " continue\n",
261
+ " if param == 'codebook.embeddings':\n",
262
+ " new_param = 'codebook.embedding'\n",
263
+ " mapped_state_dict[new_param] = val\n",
264
+ " continue\n",
265
+ " \n",
266
+ " mapped_state_dict[new_param] = val\n",
267
+ " \n",
268
+ " return mapped_state_dict"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": 8,
274
+ "id": "38234858",
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": [
278
+ "new_state_dict = map_state_dict(ckpt['state_dict'])"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 9,
284
+ "id": "e160fb51",
285
+ "metadata": {
286
+ "scrolled": false
287
+ },
288
+ "outputs": [
289
+ {
290
+ "data": {
291
+ "text/plain": [
292
+ "<All keys matched successfully>"
293
+ ]
294
+ },
295
+ "execution_count": 9,
296
+ "metadata": {},
297
+ "output_type": "execute_result"
298
+ }
299
+ ],
300
+ "source": [
301
+ "vv_torchmm.load_state_dict(new_state_dict)"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "markdown",
306
+ "id": "46d58eb7",
307
+ "metadata": {},
308
+ "source": [
309
+ "### Compare outputs with a random input"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 10,
315
+ "id": "3c85cdd3",
316
+ "metadata": {},
317
+ "outputs": [
318
+ {
319
+ "name": "stdout",
320
+ "output_type": "stream",
321
+ "text": [
322
+ "Max difference between outputs: 3.0875205993652344e-05\n",
323
+ "Mean difference between outputs: 1.7353995929170196e-07\n"
324
+ ]
325
+ }
326
+ ],
327
+ "source": [
328
+ "torch.manual_seed(4)\n",
329
+ "video = torch.randn(1,3,32,256,256) # b, c, t, h, w\n",
330
+ "\n",
331
+ "vv_mugen.eval()\n",
332
+ "vv_torchmm.eval()\n",
333
+ "\n",
334
+ "loss, x_recon, codebook_output = vv_mugen(video)\n",
335
+ "output = vv_torchmm(video)\n",
336
+ "\n",
337
+ "diff = abs(output.decoded - x_recon)\n",
338
+ "print(f'Max difference between outputs: {torch.max(diff).item()}')\n",
339
+ "print(f'Mean difference between outputs: {torch.mean(diff).item()}')"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "markdown",
344
+ "id": "fa78569e",
345
+ "metadata": {},
346
+ "source": [
347
+ "### Save mapped checkpoint"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": 9,
353
+ "id": "48651d44",
354
+ "metadata": {},
355
+ "outputs": [],
356
+ "source": [
357
+ "save_path = '/Users/rafiayub/checkpoints/generation/video_vqvae/mugen_video_vqvae_L32.pt'\n",
358
+ "torch.save(new_state_dict, save_path)"
359
+ ]
360
+ }
361
+ ],
362
+ "metadata": {
363
+ "kernelspec": {
364
+ "display_name": "Python 3 (ipykernel)",
365
+ "language": "python",
366
+ "name": "python3"
367
+ },
368
+ "language_info": {
369
+ "codemirror_mode": {
370
+ "name": "ipython",
371
+ "version": 3
372
+ },
373
+ "file_extension": ".py",
374
+ "mimetype": "text/x-python",
375
+ "name": "python",
376
+ "nbconvert_exporter": "python",
377
+ "pygments_lexer": "ipython3",
378
+ "version": "3.9.12"
379
+ }
380
+ },
381
+ "nbformat": 4,
382
+ "nbformat_minor": 5
383
+ }
multimodal/examples/mugen/generation/README.md ADDED
@@ -0,0 +1,33 @@
1
+ # Text-to-Video Generation with MUGEN
2
+
3
+ This directory contains the high-level model components for text-to-video generation following [MUGEN](https://arxiv.org/abs/2204.08058). They demonstrate how to use building blocks from TorchMultimodal to quickly assemble a new auto-regressive generative model for different pairs of modalities. Here is a [colab demo](https://colab.research.google.com/drive/1C3ZbH_l19g_KqW3CPeX2-8Q2sOUCpmZo?usp=sharing) showing how to generate a video clip from text prompts.
4
+
5
+ https://user-images.githubusercontent.com/23155714/196074330-6f03593c-da8e-473f-8935-8bf1950baa33.mp4
6
+
7
+ ```python
8
+ from torchmultimodal.utils.generate import GenerationUtil
9
+ from examples.mugen.generation.text_video_gpt import text_video_gpt
10
+
11
+
12
+ model = text_video_gpt(video_seq_len=32, pretrained_text_video_gpt_model_key="mugen_L32")
13
+ generator = GenerationUtil(model)
14
+
15
+ output = generator.sample(
16
+ ['Mugen moves left to right on a cliff and picks up a gem.'],
17
+ max_seq_len=512,
18
+ use_cache=True,
19
+ causal=True,
20
+ device=<current_device>,
21
+ )
22
+ samples = output.decoded
23
+ ```
24
+
25
+ ## Model
26
+ The model architecture used by MUGEN follows [DALL-E](https://arxiv.org/abs/2102.12092) but with the image components replaced by those for video following [VideoGPT](https://arxiv.org/abs/2104.10157).
27
+
28
+ Multimodal generation involves generation of samples in one modality given inputs from another. As in the text-to-image generation model DALL-E, it typically involves a two-stage process of first learning a discrete latent representation for each modality and then using a [GPT](https://openai.com/blog/language-unsupervised/) transformer decoder to learn a joint prior for both modalities in the latent space. For text data, the latent representation is obtained through tokenization such as [BPE](https://en.wikipedia.org/wiki/Byte_pair_encoding) used in this example. For high dimensional data such as video and image, a [VQ-VAE](https://arxiv.org/abs/1711.00937) model is used to learn a set of downsampled discrete embedding vectors through nearest-neighbor lookups from a "codebook" where the chosen indices are referred to as the token ids following convention from language modeling.
29
+
30
+ VideoGPT is a generative model for video using a VQ-VAE model with video encoder/decoder and a GPT transformer decoder for token generation. The encoder and the decoder use 3D-convolution and self axial-attention to learn video information.
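As a rough illustration of the sizes involved (a sketch, not part of the diff, using the default values from `examples/mugen/generation/text_video_gpt.py` added later in this commit):

```python
# Sketch: how the discrete latent/token sizes fall out of the defaults in
# text_video_gpt.py (video_seq_len=32, resolution=256, downsample=(4, 32, 32),
# text_seq_len=128).
video_input_shape = (32, 256, 256)  # (time, height, width)
downsample = (4, 32, 32)

latent_shape = tuple(s // d for s, d in zip(video_input_shape, downsample))
num_video_tokens = latent_shape[0] * latent_shape[1] * latent_shape[2]

print(latent_shape)             # (8, 8, 8)
print(num_video_tokens)         # 512 video token positions per clip
print(128 + num_video_tokens)   # 640 text + video positions seen by the GPT decoder
```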
31
+
32
+ ## Generation
33
+ In this example, generation refers to the auto-regressive process of iteratively predicting the next token id from the current sequence until the desired output length is reached, a technique that originated in language modeling and has since been extended to multimodal generation. To control the generation process, a top-level abstraction is provided as a utility in [generate.py](https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/utils/generate.py) which takes the model as an input.
multimodal/examples/mugen/generation/text_video_gpt.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple
8
+
9
+ import torch
10
+
11
+ from examples.mugen.generation.video_vqvae import video_vqvae_mugen
12
+
13
+ from torch import nn, Tensor
14
+
15
+ from torchmultimodal.models.video_gpt.gpt import (
16
+ MultimodalGPT,
17
+ MultimodalTransformerDecoder,
18
+ RightShift,
19
+ TransformerDecoder,
20
+ TransformerDecoderLayer,
21
+ )
22
+ from torchmultimodal.modules.layers.attention import SelfAttention
23
+ from torchmultimodal.modules.layers.position_embedding import (
24
+ BroadcastedPositionEmbedding,
25
+ )
26
+ from torchmultimodal.utils.common import load_module_from_url
27
+ from torchtext.transforms import CharBPETokenizer
28
+
29
+
30
+ PRETRAINED_TOKENIZER_ENCODER_URL = "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/tokenizer-coinrun_1024_encoder.json"
31
+ PRETRAINED_TOKENIZER_MERGES_URL = "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/tokenizer-coinrun_1024_merges.txt"
32
+ PRETRAINED_TEXT_VIDEO_GPT_URL_MAPPING = {
33
+ "mugen_L32": "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/text_video_gpt_L32_weights-17db9549.pth",
34
+ "mugen_L16": "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/text_video_gpt_L16_weights-5dfc5a0a.pth",
35
+ "mugen_L8": "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/text_video_gpt_L8_weights-72b6d2ab.pth",
36
+ }
37
+
38
+
39
+ def text_video_gpt(
40
+ text_seq_len: int = 128,
41
+ video_seq_len: int = 32,
42
+ resolution: int = 256,
43
+ downsample: Tuple[int, int, int] = (4, 32, 32),
44
+ d_model: int = 768,
45
+ n_head: int = 8,
46
+ dropout: float = 0.2,
47
+ attn_dropout: float = 0.3,
48
+ num_decoder_layers: int = 12,
49
+ use_gpt_init: bool = True,
50
+ pretrained_text_tokenizer_encoder_url: str = PRETRAINED_TOKENIZER_ENCODER_URL,
51
+ pretrained_text_tokenizer_merges_url: str = PRETRAINED_TOKENIZER_MERGES_URL,
52
+ pretrained_video_vqvae_model_key: Optional[str] = None,
53
+ pretrained_text_video_gpt_model_key: Optional[str] = None,
54
+ ) -> MultimodalGPT:
55
+ """Builds a text-to-video GPT model from user inputs
56
+
57
+ Parameter defaults follow MUGEN project:
58
+ * Video VQVAE: https://github.com/mugen-org/MUGEN_baseline/tree/main/generation/experiments/vqvae
59
+ * GPT: https://github.com/mugen-org/MUGEN_baseline/blob/main/lib/models/gpt/gpt.py#L252
60
+
61
+ Args:
62
+ text_seq_len (int): Length of text sequences after padding. Defaults to ``128``.
63
+ video_seq_len (int): Length of video sequences sampled from the dataset. Defaults to ``32``. Other
64
+ values used by MUGEN are ``8``, ``16``.
65
+ resolution (int): Resolution of the sampled video sequences defining height and width of each frame.
66
+ Defaults to ``256``.
67
+ downsample (Tuple[int, int, int]): Ratio by which to downsample the sampled sequences along each dimension.
68
+ For example, if the original frame is ``(32, 256, 256)``, after downsampling by ``(4, 32, 32)`` the
69
+ new frame will be of shape ``(8, 8, 8)`` with each dim divided by the rate of downsample. Defaults to
70
+ ``(4, 32, 32)``.
71
+ d_model (int): Dimension of the underlying transformer decoder.
72
+ See :py:class:`torchmultimodal.models.video_gpt.gpt.TransformerDecoderLayer`. Defaults to ``768``.
73
+ n_head (int): Number of attention heads used by the transformer decoder. Defaults to ``8``.
74
+ dropout (float): Dropout probability used by the projection layer of the transformer decoder.
75
+ Defaults to ``0.2``.
76
+ attn_dropout (float): Dropout probability used by the attention layer of the transformer decoder.
77
+ Defaults to ``0.3``.
78
+ num_decoder_layers (int): Number of transformer decoder layers. Defaults to ``12``.
79
+ use_gpt_init (bool): Whether to use the GPT model's parameter initialization. Defaults to ``True``.
80
+ pretrained_text_tokenizer_encoder_url (str): Remote location of the pretrained text tokenizer encoder file.
81
+ Defaults to the `MUGEN pretrained tokenizer encoder file
83
+ <https://pytorch.s3.amazonaws.com/models/multimodal/mugen/tokenizer-coinrun_1024_encoder.json>`_.
83
+ pretrained_text_tokenizer_merges_url (str): Remote location of the pretrained text tokenizer merges file.
84
+ Defaults to the `MUGEN pretrained tokenizer merges file
85
+ <https://pytorch.s3.amazonaws.com/models/multimodal/mugen/tokenizer-coinrun_1024_merges.txt>`_.
86
+ pretrained_video_vqvae_model_key (str, optional): Key to select the pretrained MUGEN VideoVQVAE weights
87
+ file. For allowed values, see :py:module:`examples/mugen/generation/video_vqvae.py`.
88
+ Defaults to ``None``.
89
+ pretrained_text_video_gpt_model_key (str, optional): Key to select the pretrained MUGEN TextVideoGPT
90
+ weights file. The provided key should match that of MUGEN VideoVQVAE to ensure the two models were
91
+ pretrained for the same video sequence length. For example ``L32`` means the video sequence length
92
+ is ``32``. The loaded weights will override those from the frozen VideoVQVAE model.
93
+ Defaults to ``None``.
94
+
95
+ Returns:
96
+ An instance of :py:class:`torchmultimodal.models.video_gpt.gpt.MultimodalGPT`.
97
+ """
98
+
99
+ # builds text tokenizer from pre-trained
100
+ tokenizer = CharBPETokenizer(
101
+ bpe_encoder_path=pretrained_text_tokenizer_encoder_url,
102
+ bpe_merges_path=pretrained_text_tokenizer_merges_url,
103
+ unk_token="[UNK]",
104
+ special_tokens=["[PAD]", "[CLS]", "[SEP]", "[UNK]", "[MASK]"],
105
+ )
106
+
107
+ # builds text tokenizer
108
+ text_tokenizer = TextTokenizer(
109
+ context_len=text_seq_len,
110
+ d_model=d_model,
111
+ tokenizer=tokenizer,
112
+ )
113
+ num_text_tokens = text_tokenizer.num_text_tokens
114
+
115
+ # builds video tokenizer
116
+ video_vqvae = video_vqvae_mugen(
117
+ pretrained_model_key=pretrained_video_vqvae_model_key,
118
+ freeze_model=True,
119
+ )
120
+ video_vqvae.eval()
121
+ num_video_tokens = video_vqvae.num_embeddings # size of the codebook
122
+
123
+ # derives the expected latent shape from video input shape
124
+ video_input_shape = (video_seq_len, resolution, resolution)
125
+ video_latent_shape = latent_shape(video_input_shape, downsample)
126
+ video_vqvae_latent_shape = video_vqvae.latent_shape(video_input_shape)
127
+ # video vqvae will apply convolutions to the input shape which effectively
128
+ # reduces each dim to ``dim // stride`` after each layer
129
+ # sanity check that the expected and actual latent shapes are consistent
130
+ if video_latent_shape != video_vqvae_latent_shape:
131
+ raise ValueError(
132
+ f"Latent shape derived from video inputs: {video_latent_shape} "
133
+ f"does not match that of video vqvae: {video_vqvae_latent_shape}"
134
+ )
135
+
136
+ # builds text embedding projection: text_emb is already of output shape `d_model`
137
+ # generally a projection layer is needed to bridge the tokenizer and
138
+ # `torchmultimodal.models.gpt.MultimodalTransformerDecoder`, see `video_projection`
139
+ text_projection = nn.Identity()
140
+
141
+ # builds video embedding projection
142
+ video_projection = nn.Linear(video_vqvae.embedding_dim, d_model, bias=False)
143
+
144
+ # builds multimodal decoder
145
+ text_pos_emb = nn.Embedding(text_seq_len, d_model)
146
+ video_pos_emb = BroadcastedPositionEmbedding(video_latent_shape, d_model)
147
+ attention_layer = SelfAttention(attn_dropout=attn_dropout)
148
+ decoder_layer = TransformerDecoderLayer(
149
+ d_model, n_head, dropout, attn_module=attention_layer
150
+ )
151
+ decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
152
+ right_shift = RightShift(d_model)
153
+ mm_decoder = MultimodalTransformerDecoder(
154
+ text_pos_emb, video_pos_emb, decoder, right_shift
155
+ )
156
+
157
+ model = MultimodalGPT(
158
+ d_model=d_model,
159
+ num_in_tokens=num_text_tokens,
160
+ num_out_tokens=num_video_tokens,
161
+ latent_shape=video_latent_shape,
162
+ in_tokenizer=text_tokenizer,
163
+ out_tokenizer=video_vqvae,
164
+ mm_decoder=mm_decoder,
165
+ in_projection=text_projection,
166
+ out_projection=video_projection,
167
+ use_gpt_init=use_gpt_init,
168
+ )
169
+
170
+ if pretrained_text_video_gpt_model_key is not None:
171
+ if (
172
+ pretrained_text_video_gpt_model_key
173
+ not in PRETRAINED_TEXT_VIDEO_GPT_URL_MAPPING
174
+ ):
175
+ raise KeyError(
176
+ f"Invalid pretrained model key: {pretrained_text_video_gpt_model_key}"
177
+ )
178
+
179
+ load_module_from_url(
180
+ model,
181
+ PRETRAINED_TEXT_VIDEO_GPT_URL_MAPPING[pretrained_text_video_gpt_model_key],
182
+ )
183
+
184
+ return model
185
+
186
+
187
+ def latent_shape(
188
+ input_shape: Tuple[int, ...], downsample: Tuple[int, ...]
189
+ ) -> Tuple[int, ...]:
190
+ """Derives latent shape of video inputs after VQ-VAE encoding"""
191
+ return tuple([s // d for s, d in zip(input_shape, downsample)])
192
+
193
+
194
+ class TextTokenizer(nn.Module):
195
+ """Converts between text and tokens / embedings
196
+
197
+ Wrapper around the tokenizer to be consistent with the API required by
198
+ :py:class:`torchmultimodal.models.video_gpt.gpt.MultimodalGPT`. It also contains the
199
+ embedding layer to enable lookup by token ids.
200
+ """
201
+
202
+ def __init__(
203
+ self,
204
+ context_len: int,
205
+ d_model: int,
206
+ tokenizer: nn.Module,
207
+ ) -> None:
208
+ super().__init__()
209
+ self.tokenizer = tokenizer
210
+ self.pad_id = self.tokenizer.encode("[PAD]")[0] # type: ignore
211
+ self.vocab_size = self.tokenizer.vocab_size # type: ignore
212
+ self.context_len = context_len
213
+ # MUGEN treats padding as unique ids so adding them to the total text tokens
214
+ # https://github.com/mugen-org/MUGEN_baseline/blob/main/lib/models/gpt/gpt.py#L44
215
+ self.num_text_tokens = self.vocab_size + context_len
216
+ self.embedding = nn.Embedding(self.num_text_tokens, d_model)
217
+
218
+ def text_to_tokens(self, sentences: List[str]) -> Tensor:
219
+ """Pads the sentences to be of equal lengths"""
220
+ tokens = [
221
+ self.tokenizer.encode(sentence.strip().lower() + " [SEP]") # type: ignore
222
+ for sentence in sentences
223
+ ]
224
+ token_ids = [t[: self.context_len] for t in tokens]
225
+ # pad each sentence to be of length `context_len`
226
+ for i, t in enumerate(token_ids):
227
+ t += [self.pad_id] * (self.context_len - len(t))
228
+ token_ids[i] = t
229
+
230
+ return torch.Tensor(token_ids).type(torch.int64)
231
+
232
+ def encode(self, sentences: List[str], device: str) -> Tensor:
233
+ """Encodes sentences to token ids"""
234
+ token_ids = self.text_to_tokens(sentences).to(device)
235
+ # bump padding token ids by vocab_size so that they do not coincide with un-padded token ids
236
+ # and that the padding token ids themselves are unique
237
+ unique_pad_ids = torch.arange(self.context_len, device=device) + self.vocab_size
238
+ token_ids = torch.where(token_ids == self.pad_id, unique_pad_ids, token_ids)
239
+ return token_ids
240
+
241
+ def _filter_token_ids(self, token_ids: List[int]) -> List[Optional[int]]:
242
+ """Filters out token ids out side of vocab"""
243
+ return [
244
+ token_id
245
+ for token_id in token_ids
246
+ if token_id > 0 and token_id <= self.vocab_size
247
+ ]
248
+
249
+ def decode(self, token_ids: Tensor) -> List[str]:
250
+ """Decodes token ids back to sentences"""
251
+ sentences = []
252
+ for _token_ids in token_ids: # iterate over batches
253
+ _token_ids = self._filter_token_ids(_token_ids.tolist())
254
+ sentence = self.tokenizer.decode(_token_ids) # type: ignore
255
+ sentences.append(sentence)
256
+
257
+ return sentences
258
+
259
+ def lookup(self, token_ids: Tensor) -> Tensor:
260
+ return self.embedding(token_ids)
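A small sketch (not part of the diff) of the padding trick used by `TextTokenizer.encode` above: padded positions are remapped to unique ids above the BPE vocabulary, which is why `num_text_tokens = vocab_size + context_len`. The vocabulary size below is illustrative.

```python
import torch

vocab_size, context_len, pad_id = 1024, 8, 0  # illustrative values
token_ids = torch.tensor([[5, 17, 9, pad_id, pad_id, pad_id, pad_id, pad_id]])

# each padded position gets its own id above the vocabulary
unique_pad_ids = torch.arange(context_len) + vocab_size
token_ids = torch.where(token_ids == pad_id, unique_pad_ids, token_ids)

print(token_ids)  # tensor([[   5,   17,    9, 1027, 1028, 1029, 1030, 1031]])
```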
multimodal/examples/mugen/generation/video_vqvae.py ADDED
@@ -0,0 +1,113 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ from torchmultimodal.models.video_gpt.video_vqvae import (
10
+ preprocess_int_conv_params,
11
+ VideoDecoder,
12
+ VideoEncoder,
13
+ )
14
+
15
+ from torchmultimodal.models.vqvae import VQVAE
16
+ from torchmultimodal.utils.common import load_module_from_url, remove_grad
17
+
18
+
19
+ MUGEN_PRETRAINED_MAPPING = {
20
+ "mugen_L32": "https://download.pytorch.org/models/multimodal/mugen/mugen_video_vqvae_L32.pt",
21
+ "mugen_L16": "https://download.pytorch.org/models/multimodal/mugen/mugen_video_vqvae_L16.pt",
22
+ "mugen_L8": "https://download.pytorch.org/models/multimodal/mugen/mugen_video_vqvae_L8.pt",
23
+ }
24
+
25
+
26
+ def video_vqvae_mugen(
27
+ in_channel_dim: int = 3,
28
+ encoder_hidden_dim: int = 240,
29
+ encoder_kernel_size: int = 3,
30
+ n_res_layers: int = 4,
31
+ attn_hidden_dim: int = 240,
32
+ num_embeddings: int = 2048,
33
+ embedding_dim: int = 256,
34
+ decoder_hidden_dim: int = 240,
35
+ decoder_kernel_size: int = 3,
36
+ pretrained_model_key: Optional[str] = None,
37
+ freeze_model: bool = False,
38
+ ) -> VQVAE:
39
+ """Constructor for MUGEN's Video VQVAE. Expects input video data of shape ``{8,16,32}x256x256``.
40
+ Trained for tokenization of video data and use in video-audio-text retrieval and generation tasks.
41
+ See Hayes et al. 2022 for more details: https://arxiv.org/pdf/2204.08058.pdf
42
+ Code ref:
43
+ https://github.com/mugen-org/MUGEN_baseline/blob/main/lib/models/video_vqvae/vqvae.py
44
+ https://github.com/mugen-org/MUGEN_baseline/blob/main/generation/experiments/vqvae/VideoVQVAE_L32.sh
45
+
46
+ Args:
47
+ in_channel_dim (int, optional): Size of channel dim in input. Defaults to ``3``.
48
+ encoder_hidden_dim (int, optional): Size of channel dims in encoder conv layers. Defaults to ``240``.
49
+ encoder_kernel_size (int, optional): Kernel size for encoder. Defaults to ``3``.
50
+ n_res_layers (int, optional): Number of ``AttentionResidualBlocks`` to include in encoder and decoder.
51
+ Defaults to ``4``.
52
+ attn_hidden_dim (int, optional): Size of hidden dim of
53
+ :class:`~torchmultimodal.models.video_gpt.video_vqvae.AttentionResidualBlocks`. Defaults to ``240``.
54
+ num_embeddings (int, optional): Number of embedding vectors used in
55
+ :class:`~torchmultimodal.modules.layers.codebook.Codebook`. Defaults to ``2048``.
56
+ embedding_dim (int, optional): Dimensionality of embedding vectors in
57
+ :class:`~torchmultimodal.modules.layers.codebook.Codebook`. Defaults to ``256``.
58
+ decoder_hidden_dim (int, optional): Size of channel dims in decoder conv transpose layers.
59
+ Defaults to ``240``.
60
+ decoder_kernel_size (int, optional): Kernel size for decoder. Defaults to ``3``.
61
+ pretrained_model_key (str, optional): Load a specified MUGEN VQVAE checkpoint.
62
+ freeze_model (bool): Whether to freeze the weights of the pretrained model. Defaults to ``False``.
63
+
64
+ Returns:
65
+ An instance of :class:`~torchmultimodal.models.vqvae.VQVAE` constructed with:
66
+ * :class:`~torchmultimodal.model.video_gpt.video_vqvae.VideoEncoder`
67
+ * :class:`~torchmultimodal.model.video_gpt.video_vqvae.VideoDecoder`
68
+ """
69
+ encoder_strides = ((2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2), (1, 1, 1))
70
+ decoder_strides = ((2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2))
71
+ encoder_n_layers = len(encoder_strides)
72
+ decoder_n_layers = len(decoder_strides)
73
+ encoder_in_channel_dims = (in_channel_dim,) + (encoder_hidden_dim,) * max(
74
+ encoder_n_layers - 1, 0
75
+ )
76
+ decoder_out_channel_dims = (decoder_hidden_dim,) * max(decoder_n_layers - 1, 0) + (
77
+ in_channel_dim,
78
+ )
79
+ encoder_kernel_sizes_fixed = preprocess_int_conv_params(
80
+ encoder_in_channel_dims, encoder_kernel_size
81
+ )
82
+ decoder_kernel_sizes_fixed = preprocess_int_conv_params(
83
+ decoder_out_channel_dims, decoder_kernel_size
84
+ )
85
+
86
+ encoder = VideoEncoder(
87
+ encoder_in_channel_dims,
88
+ encoder_kernel_sizes_fixed,
89
+ encoder_strides,
90
+ embedding_dim,
91
+ n_res_layers,
92
+ attn_hidden_dim,
93
+ )
94
+ decoder = VideoDecoder(
95
+ decoder_out_channel_dims,
96
+ decoder_kernel_sizes_fixed,
97
+ decoder_strides,
98
+ embedding_dim,
99
+ n_res_layers,
100
+ attn_hidden_dim,
101
+ )
102
+ model = VQVAE(encoder, decoder, num_embeddings, embedding_dim)
103
+
104
+ if pretrained_model_key is not None:
105
+ if pretrained_model_key not in MUGEN_PRETRAINED_MAPPING.keys():
106
+ raise KeyError(f"Invalid pretrained model key: {pretrained_model_key}")
107
+
108
+ load_module_from_url(model, MUGEN_PRETRAINED_MAPPING[pretrained_model_key])
109
+
110
+ if freeze_model:
111
+ remove_grad(model)
112
+
113
+ return model
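As a quick sanity check (a sketch, not part of the diff), the encoder strides hard-coded above compound to the `(4, 32, 32)` downsample ratio that the builder documents: time is halved twice while height and width are halved five times.

```python
from functools import reduce

encoder_strides = ((2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2), (1, 1, 1))

# multiply the per-dimension strides across layers
downsample = tuple(reduce(lambda a, b: a * b, dims) for dims in zip(*encoder_strides))
print(downsample)  # (4, 32, 32)
```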
multimodal/examples/mugen/retrieval/README.md ADDED
@@ -0,0 +1,34 @@
1
+ # MUGEN Retrieval
2
+
3
+ This directory contains reference training and evaluation scripts for MUGEN's video-text retrieval model, along with a tutorial notebook on model usage, available on [Colab](https://colab.research.google.com/drive/1gZfz1jsy79CNCK9t2_r43yt3z7v-w4HS?usp=sharing) or [GitHub](https://github.com/facebookresearch/multimodal/blob/main/examples/mugen/retrieval/tutorial.ipynb).
4
+
5
+ ## Model
6
+ MUGEN's video-text retrieval model follows from [VideoCLIP](https://arxiv.org/abs/2109.14084), a contrastive model for video and text.
7
+
8
+ The name "VideoCLIP" refers to its similarities to OpenAI's [CLIP](https://arxiv.org/abs/2103.00020), which was originally proposed for zero-shot learning of image classification tasks by “drawing cues” from text data with the corresponding visual concepts. Unlike various predecessor models based on supervised learning, CLIP does not have to be trained on the task-specific datasets or fine-tuned with a task-specific head. The model learns a joint embedding space for both image and text data and optimizes a scaled cosine similarity function between the image and text embedding vectors. The loss function is the sum of the normalized cosine similarities for every pair of image-and-text samples. Each embedding is trained with a unimodal encoder, e.g., a transformer for text, vision transformer (ViT) or ResNet for image.
9
+
10
+ The VideoCLIP model follows the CLIP architecture but replaces the image encoder with a video encoder. VideoCLIP's video encoder is backed by [Separable 3D CNN (S3D)](https://arxiv.org/abs/1712.04851), a video classification model, and the text encoder is backed by [DistilBERT](https://arxiv.org/abs/1910.01108), a lightweight transformer for language modeling.
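For intuition, here is a minimal sketch of a CLIP-style contrastive objective. It is an illustration only; the temperature parameterization is simplified and need not match the `ContrastiveLossWithTemperature` module actually used in `retrieval/model.py`.

```python
import torch
import torch.nn.functional as F

def clip_style_loss(text_emb, video_emb, temperature=0.07):
    # scaled cosine similarities between every text/video pair in the batch
    text_emb = F.normalize(text_emb, dim=-1)
    video_emb = F.normalize(video_emb, dim=-1)
    logits = text_emb @ video_emb.T / temperature
    # matching pairs lie on the diagonal
    targets = torch.arange(text_emb.size(0))
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))

loss = clip_style_loss(torch.randn(4, 256), torch.randn(4, 256))
```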
11
+
12
+ ## Training
13
+ The configurable parameters for training can be found in `configs/train.yaml`. Note that the training script supports training on one or more devices on a single node. After adjusting the config as needed, run the following command:
14
+ ```
15
+ python train.py config=configs/train.yaml
16
+ ```
17
+ A checkpoint file with the best-performing weights will be saved under `{default_root_dir}/lightning_logs/`, where `default_root_dir` is specified in the training config. If `default_root_dir` is `null`, the current working directory is used instead.
18
+
19
+ ## Evaluation
20
+ The configurable parameters for evaluation can be found in `configs/eval.yaml`. You can choose to replace `checkpoint_path` with the path to your checkpoint from the training step, or keep the default `checkpoint_path` to load the MUGEN authors' weights (fit to our implementation). Then run the following command:
21
+ ```
22
+ python eval.py config=configs/eval.yaml
23
+ ```
24
+
25
+ Using the default arguments in `configs/eval.yaml` (including the MUGEN authors' published weights), we ran the evaluation script on the full MUGEN test set and got the following results:
26
+
27
+ | Metric (%) | MUGEN Results | TorchMultimodal Results |
28
+ | ----------- | ----------- | ----------- |
29
+ | Text2video top-1 recall | 8.54 | 8.26 |
30
+ | Text2video top-5 recall | 22.50 | 22.34 |
31
+ | Text2video top-10 recall | 31.71 | 31.68 |
32
+ | Video2text top-1 recall | 10.61 | 10.79 |
33
+ | Video2text top-5 recall | 25.72 | 25.70 |
34
+ | Video2text top-10 recall | 34.70 | 34.60 |
multimodal/examples/mugen/retrieval/configs/eval.yaml ADDED
@@ -0,0 +1,48 @@
1
+ _target_: examples.mugen.retrieval.definitions.EvaluationArgs
2
+ dataset_args:
3
+ _target_: examples.mugen.data.mugen_dataset.MUGENDatasetArgs
4
+ data_path: "datasets/coinrun/coinrun_dataset_jsons/release"
5
+ asset_path: "datasets/coinrun/assets"
6
+ sample_every_n_frames: 3
7
+ sequence_length: 32
8
+ audio_sample_rate: 22050
9
+ audio_sample_length: 70560
10
+ resolution: 256
11
+ bbox_smap_for_agent: False
12
+ bbox_smap_for_monsters: False
13
+ use_manual_annotation: True
14
+ use_auto_annotation: False
15
+ use_downsampled_trainset: False
16
+ fixed_start_idx: False
17
+ get_game_frame: True
18
+ get_seg_map: False
19
+ get_text_desc: True
20
+ get_audio: False
21
+ debug: False
22
+ datamodule_args:
23
+ _target_: examples.mugen.retrieval.definitions.DataModuleArgs
24
+ batch_size: 16
25
+ num_workers: 4
26
+ shuffle: False
27
+ bert_text_transform:
28
+ _target_: examples.mugen.retrieval.definitions.BertTextTransformArgs
29
+ video_transform:
30
+ _target_: examples.mugen.retrieval.definitions.VideoTransformArgs
31
+ lightningmodule_args:
32
+ _target_: examples.mugen.retrieval.definitions.LightningModuleArgs
33
+ logit_scale: 0.07
34
+ logit_scale_max: 100.0
35
+ videoclip_args:
36
+ _target_: examples.mugen.retrieval.definitions.VideoCLIPArgs
37
+ text_pretrained: False
38
+ text_trainable: False
39
+ text_model_name: "distilbert-base-uncased"
40
+ text_model_config: null
41
+ text_padding_value: 0
42
+ video_pretrained: False
43
+ video_trainable: False
44
+ video_pretrain_path: "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/S3D_kinetics400.pt"
45
+ proj_out_dim: 256
46
+ proj_dropout: 0.1
47
+ checkpoint_path: "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/videoclip_lightning_mugen.pt"
48
+ accelerator: "auto"
multimodal/examples/mugen/retrieval/configs/train.yaml ADDED
@@ -0,0 +1,53 @@
1
+ _target_: examples.mugen.retrieval.definitions.TrainingArgs
2
+ dataset_args:
3
+ _target_: examples.mugen.data.mugen_dataset.MUGENDatasetArgs
4
+ data_path: "datasets/coinrun/coinrun_dataset_jsons/release"
5
+ asset_path: "datasets/coinrun/assets"
6
+ sample_every_n_frames: 3
7
+ sequence_length: 32
8
+ audio_sample_rate: 22050
9
+ audio_sample_length: 70560
10
+ resolution: 224
11
+ bbox_smap_for_agent: False
12
+ bbox_smap_for_monsters: False
13
+ use_manual_annotation: True
14
+ use_auto_annotation: False
15
+ use_downsampled_trainset: False
16
+ fixed_start_idx: False
17
+ get_game_frame: True
18
+ get_seg_map: False
19
+ get_text_desc: True
20
+ get_audio: False
21
+ debug: False
22
+ datamodule_args:
23
+ _target_: examples.mugen.retrieval.definitions.DataModuleArgs
24
+ batch_size: 16
25
+ num_workers: 4
26
+ shuffle: False
27
+ bert_text_transform:
28
+ _target_: examples.mugen.retrieval.definitions.BertTextTransformArgs
29
+ video_transform:
30
+ _target_: examples.mugen.retrieval.definitions.VideoTransformArgs
31
+ lightningmodule_args:
32
+ _target_: examples.mugen.retrieval.definitions.LightningModuleArgs
33
+ logit_scale: 0.07
34
+ logit_scale_max: 100.0
35
+ learning_rate: 0.001
36
+ weight_decay: 0.001
37
+ videoclip_args:
38
+ _target_: examples.mugen.retrieval.definitions.VideoCLIPArgs
39
+ text_pretrained: True
40
+ text_trainable: False
41
+ text_model_name: "distilbert-base-uncased"
42
+ text_model_config: null
43
+ text_padding_value: 0
44
+ video_pretrained: True
45
+ video_trainable: True
46
+ video_pretrain_path: "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/S3D_kinetics400.pt"
47
+ proj_out_dim: 256
48
+ proj_dropout: 0.1
49
+ accelerator: "auto"
50
+ devices: 4
51
+ max_epochs: 20
52
+ log_every_n_steps: 100
53
+ default_root_dir: null
multimodal/examples/mugen/retrieval/definitions.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Any, Dict, Optional, Tuple
9
+
10
+ from examples.mugen.data.mugen_dataset import MUGENDatasetArgs
11
+
12
+ from torchmultimodal.transforms.video_transform import (
13
+ DEFAULT_MEAN,
14
+ DEFAULT_RESIZE_SHAPE,
15
+ DEFAULT_STD,
16
+ MUGEN_DEFAULT_TIME_SAMPLES,
17
+ )
18
+
19
+
20
+ @dataclass
21
+ class BertTextTransformArgs:
22
+ vocab_file: str = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
23
+ do_lower_case: bool = True
24
+ start_token: int = 101
25
+ end_token: int = 102
26
+ padding_value: int = 0
27
+
28
+
29
+ @dataclass
30
+ class VideoTransformArgs:
31
+ time_samples: int = MUGEN_DEFAULT_TIME_SAMPLES
32
+ mean: Tuple[float] = DEFAULT_MEAN
33
+ std: Tuple[float] = DEFAULT_STD
34
+ resize_shape: Tuple[int, int] = DEFAULT_RESIZE_SHAPE
35
+
36
+
37
+ @dataclass
38
+ class DataModuleArgs:
39
+ batch_size: int = 16
40
+ num_workers: int = 4
41
+ shuffle: bool = False
42
+ bert_text_transform: BertTextTransformArgs = BertTextTransformArgs()
43
+ video_transform: VideoTransformArgs = VideoTransformArgs()
44
+
45
+
46
+ @dataclass
47
+ class LightningModuleArgs:
48
+ logit_scale: float = 0.07
49
+ logit_scale_max: float = 100.0
50
+ learning_rate: float = 1e-3
51
+ weight_decay: float = 1e-3
52
+ recall_ks: Tuple[int] = (1, 5, 10)
53
+
54
+
55
+ @dataclass
56
+ class VideoCLIPArgs:
57
+ text_pretrained: bool = False
58
+ text_trainable: bool = False
59
+ text_model_name: str = "distilbert-base-uncased"
60
+ text_model_config: Optional[Dict[str, Any]] = None
61
+ text_padding_value: int = 0
62
+ video_pretrained: bool = False
63
+ video_trainable: bool = False
64
+ video_pretrain_path: str = (
65
+ "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/S3D_kinetics400.pt"
66
+ )
67
+ proj_out_dim: int = 256
68
+ proj_dropout: float = 0.1
69
+
70
+
71
+ @dataclass
72
+ class EvaluationArgs:
73
+ dataset_args: MUGENDatasetArgs = MUGENDatasetArgs(
74
+ get_game_frame=True,
75
+ get_text_desc=True,
76
+ resolution=256,
77
+ fixed_start_idx=False,
78
+ use_manual_annotation=True,
79
+ use_auto_annotation=False,
80
+ )
81
+ datamodule_args: DataModuleArgs = DataModuleArgs()
82
+ lightningmodule_args: LightningModuleArgs = LightningModuleArgs()
83
+ videoclip_args: VideoCLIPArgs = VideoCLIPArgs()
84
+ checkpoint_path: str = "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/videoclip_lightning_mugen.pt"
85
+ accelerator: str = "auto"
86
+
87
+
88
+ @dataclass
89
+ class TrainingArgs:
90
+ dataset_args: MUGENDatasetArgs = MUGENDatasetArgs(
91
+ get_game_frame=True,
92
+ get_text_desc=True,
93
+ resolution=224,
94
+ fixed_start_idx=False,
95
+ use_manual_annotation=True,
96
+ use_auto_annotation=False,
97
+ )
98
+ datamodule_args: DataModuleArgs = DataModuleArgs()
99
+ lightningmodule_args: LightningModuleArgs = LightningModuleArgs()
100
+ videoclip_args: VideoCLIPArgs = VideoCLIPArgs()
101
+ accelerator: str = "auto"
102
+ devices: int = 4
103
+ max_epochs: int = 1000
104
+ log_every_n_steps: int = 100
105
+ default_root_dir: Optional[str] = None
multimodal/examples/mugen/retrieval/eval.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from examples.mugen.data.bert_text_transform import BertTextTransform
8
+ from examples.mugen.data.mugen_datamodules import MUGENDataModule
9
+ from examples.mugen.data.mugen_dataset import MUGENDatasetArgs
10
+ from examples.mugen.retrieval.model import VideoCLIPLightningModule
11
+ from hydra.utils import instantiate
12
+ from omegaconf import OmegaConf
13
+ from pytorch_lightning import Trainer
14
+ from torchmultimodal.transforms.video_transform import VideoTransform
15
+
16
+
17
+ def get_yaml_config():
18
+ cli_conf = OmegaConf.from_cli()
19
+ if "config" not in cli_conf:
20
+ raise ValueError(
21
+ "Please pass 'config' to specify configuration yaml file for running VideoCLIP evaluation"
22
+ )
23
+ yaml_conf = OmegaConf.load(cli_conf.config)
24
+ conf = instantiate(yaml_conf)
25
+ return conf
26
+
27
+
28
+ def evaluate():
29
+ args = get_yaml_config()
30
+
31
+ dataset_args: MUGENDatasetArgs = args.dataset_args
32
+ datamodule = MUGENDataModule(
33
+ dataset_args,
34
+ text_transform=BertTextTransform(
35
+ **vars(args.datamodule_args.bert_text_transform)
36
+ ),
37
+ video_transform=VideoTransform(**vars(args.datamodule_args.video_transform)),
38
+ batch_size=args.datamodule_args.batch_size,
39
+ num_workers=args.datamodule_args.num_workers,
40
+ shuffle=args.datamodule_args.shuffle,
41
+ )
42
+
43
+ model = VideoCLIPLightningModule.load_from_checkpoint(
44
+ args.checkpoint_path,
45
+ **vars(args.lightningmodule_args),
46
+ **vars(args.videoclip_args),
47
+ )
48
+
49
+ trainer = Trainer(accelerator=args.accelerator, devices=1)
50
+ trainer.test(model, dataloaders=datamodule.test_dataloader())
51
+
52
+
53
+ if __name__ == "__main__":
54
+ evaluate()
multimodal/examples/mugen/retrieval/model.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+ from typing import Any, Tuple
9
+
10
+ import torch
11
+
12
+ from examples.mugen.retrieval.video_clip import videoclip
13
+ from pytorch_lightning import LightningModule
14
+ from torchmetrics import Recall
15
+
16
+ from torchmultimodal.modules.losses.contrastive_loss_with_temperature import (
17
+ ContrastiveLossWithTemperature,
18
+ )
19
+
20
+
21
+ class VideoCLIPLightningModule(LightningModule):
22
+ """PyTorch Lightning module for evaluating VideoCLIP model.
23
+ Args:
24
+ logit_scale (float): Initial log-temperature value for contrastive loss function.
25
+ Defaults to ``0.07``, MUGEN's log-temperature value at initialization.
26
+ logit_scale_max (float): Maximum log-temperature value for contrastive loss function.
27
+ Defaults to ``100``, MUGEN's maximum log-temperature value.
28
+ learning_rate (float): optimizer learning rate.
29
+ Defaults to ``1e-3``, MUGEN's learning rate.
30
+ weight_decay (float): optimizer weight decay.
31
+ Defaults to ``1e-3``, MUGEN's weight decay.
32
+ recall_ks (Tuple[int]): tuple of top-``k``'s for calculating recall.
33
+ Defaults to ``(1, 5, 10)``, i.e. top-1 recall, top-5 recall, and top-10 recall.
34
+ **videoclip_kwargs (Any): Keyword arguments for the videoCLIP model builder.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ logit_scale: float = 0.07,
40
+ logit_scale_max: float = 100,
41
+ learning_rate: float = 1e-3,
42
+ weight_decay: float = 1e-3,
43
+ recall_ks: Tuple[int] = (1, 5, 10),
44
+ **videoclip_kwargs: Any,
45
+ ):
46
+ super().__init__()
47
+ self.model = videoclip(**videoclip_kwargs)
48
+ self.contrastive_loss = ContrastiveLossWithTemperature(
49
+ logit_scale=logit_scale,
50
+ logit_scale_min=None,
51
+ logit_scale_max=logit_scale_max,
52
+ )
53
+ self.lr = learning_rate
54
+ self.weight_decay = weight_decay
55
+
56
+ self.recall_ks = set(recall_ks)
57
+ if len(self.recall_ks) != len(recall_ks):
58
+ warnings.warn("Duplicate `k` values in `recall_ks` are ignored.")
59
+ self.metrics = torch.nn.ModuleDict()
60
+ for k in self.recall_ks:
61
+ self.metrics.update(
62
+ {f"v2t_recall_{k}": Recall(top_k=k), f"t2v_recall_{k}": Recall(top_k=k)}
63
+ )
64
+
65
+ def _collect_embeddings(self, outputs):
66
+ text_embeddings = [batch.embeddings_a for batch in outputs]
67
+ video_embeddings = [batch.embeddings_b for batch in outputs]
68
+
69
+ embeddings = {
70
+ "text": torch.cat(text_embeddings),
71
+ "video": torch.cat(video_embeddings),
72
+ }
73
+ return embeddings
74
+
75
+ def _compute_recall(self, split, text_embedding, video_embedding):
76
+ similarity_matrix = text_embedding @ video_embedding.T
77
+ num_samples = similarity_matrix.shape[0]
78
+ target_matrix = torch.eye(
79
+ n=num_samples, dtype=int, device=similarity_matrix.device
80
+ )
81
+
82
+ for k in self.recall_ks:
83
+ v2t_recall = self.metrics[f"v2t_recall_{k}"]
84
+ v2t_recall(preds=similarity_matrix.T, target=target_matrix)
85
+ self.log(f"{split}/Recall@{k} (video query, text retrieval)", v2t_recall)
86
+
87
+ t2v_recall = self.metrics[f"t2v_recall_{k}"]
88
+ t2v_recall(preds=similarity_matrix, target=target_matrix)
89
+ self.log(f"{split}/Recall@{k} (text query, video retrieval)", t2v_recall)
90
+
91
+ def configure_optimizers(self):
92
+ params = self.parameters()
93
+ optimizer = torch.optim.AdamW(
94
+ params, lr=self.lr, weight_decay=self.weight_decay
95
+ )
96
+ return optimizer
97
+
98
+ def training_step(self, batch, batch_idx):
99
+ text, video = batch.get("text"), batch.get("video")
100
+ model_output = self.model(features_a=text, features_b=video)
101
+ loss = self.contrastive_loss(
102
+ model_output.embeddings_a, model_output.embeddings_b
103
+ )
104
+ self.log(
105
+ "train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
106
+ )
107
+ return {"loss": loss, "model_output": model_output}
108
+
109
+ def validation_step(self, batch, batch_idx):
110
+ text, video = batch.get("text"), batch.get("video")
111
+ model_output = self.model(features_a=text, features_b=video)
112
+ loss = self.contrastive_loss(
113
+ model_output.embeddings_a, model_output.embeddings_b
114
+ )
115
+ self.log(
116
+ "validation/loss",
117
+ loss,
118
+ on_step=True,
119
+ on_epoch=True,
120
+ prog_bar=True,
121
+ logger=True,
122
+ )
123
+ return {"loss": loss, "model_output": model_output}
124
+
125
+ def validation_epoch_end(self, outputs):
126
+ model_outputs = [batch["model_output"] for batch in outputs]
127
+ all_embeddings = self._collect_embeddings(model_outputs)
128
+ text_embedding, video_embedding = (
129
+ all_embeddings["text"],
130
+ all_embeddings["video"],
131
+ )
132
+ self._compute_recall("validation", text_embedding, video_embedding)
133
+
134
+ def test_step(self, batch, batch_idx):
135
+ text, video = batch.get("text"), batch.get("video")
136
+ model_output = self.model(features_a=text, features_b=video)
137
+ return model_output
138
+
139
+ def test_epoch_end(self, outputs):
140
+ all_embeddings = self._collect_embeddings(outputs)
141
+ text_embedding, video_embedding = (
142
+ all_embeddings["text"],
143
+ all_embeddings["video"],
144
+ )
145
+ self._compute_recall("test", text_embedding, video_embedding)
multimodal/examples/mugen/retrieval/train.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from examples.mugen.data.bert_text_transform import BertTextTransform
8
+ from examples.mugen.data.mugen_datamodules import MUGENDataModule
9
+ from examples.mugen.data.mugen_dataset import MUGENDatasetArgs
10
+ from examples.mugen.retrieval.model import VideoCLIPLightningModule
11
+ from hydra.utils import instantiate
12
+ from omegaconf import OmegaConf
13
+ from pytorch_lightning import Trainer
14
+ from pytorch_lightning.callbacks import ModelCheckpoint
15
+ from torchmultimodal.transforms.video_transform import VideoTransform
16
+
17
+
18
+ def get_yaml_config():
19
+ cli_conf = OmegaConf.from_cli()
20
+ if "config" not in cli_conf:
21
+ raise ValueError(
22
+ "Please pass 'config' to specify configuration yaml file for running VideoCLIP training"
23
+ )
24
+ yaml_conf = OmegaConf.load(cli_conf.config)
25
+ conf = instantiate(yaml_conf)
26
+ return conf
27
+
28
+
29
+ def train():
30
+ args = get_yaml_config()
31
+
32
+ dataset_args: MUGENDatasetArgs = args.dataset_args
33
+ datamodule = MUGENDataModule(
34
+ dataset_args,
35
+ text_transform=BertTextTransform(
36
+ **vars(args.datamodule_args.bert_text_transform)
37
+ ),
38
+ video_transform=VideoTransform(**vars(args.datamodule_args.video_transform)),
39
+ batch_size=args.datamodule_args.batch_size,
40
+ num_workers=args.datamodule_args.num_workers,
41
+ shuffle=args.datamodule_args.shuffle,
42
+ )
43
+
44
+ model = VideoCLIPLightningModule(
45
+ **vars(args.lightningmodule_args),
46
+ **vars(args.videoclip_args),
47
+ )
48
+
49
+ checkpoint_callback = ModelCheckpoint(save_top_k=-1)
50
+ trainer = Trainer(
51
+ accelerator=args.accelerator,
52
+ devices=args.devices,
53
+ strategy="ddp_find_unused_parameters_false",
54
+ max_epochs=args.max_epochs,
55
+ log_every_n_steps=args.log_every_n_steps,
56
+ default_root_dir=args.default_root_dir,
57
+ callbacks=[checkpoint_callback],
58
+ )
59
+ trainer.fit(
60
+ model=model,
61
+ train_dataloaders=datamodule.train_dataloader(),
62
+ val_dataloaders=datamodule.val_dataloader(),
63
+ )
64
+
65
+
66
+ if __name__ == "__main__":
67
+ train()