jaketae commited on
Commit
6b9773c
·
1 Parent(s): 28babb8

feature: add coco_only model ckpt

Browse files
.gitattributes CHANGED
File without changes
.gitignore CHANGED
File without changes
Makefile CHANGED
File without changes
README.md CHANGED
@@ -1,6 +1,6 @@
1
- # [WIP] Korean CLIP Model
2
 
3
- Korean version of CLIP model. We are using Klue text model with image-text data pairs.
4
 
5
  ## Installation
6
 
 
1
+ # KoCLIP
2
 
3
+ This repository includes
4
 
5
  ## Installation
6
 
configuration_hybrid_clip.py CHANGED
File without changes
dataloader.py CHANGED
@@ -53,7 +53,7 @@ class ImageTextDataset(VisionDataset):
53
  self,
54
  root: str,
55
  file_path: str,
56
- captions_per_image=2,
57
  transform: Optional[Callable] = None,
58
  target_transform: Optional[Callable] = None,
59
  transforms: Optional[Callable] = None,
@@ -61,7 +61,7 @@ class ImageTextDataset(VisionDataset):
61
  super().__init__(root, transforms, transform, target_transform)
62
 
63
  with open(file_path, "r") as f:
64
- examples = [json.loads(line) for line in f.readlines()]
65
 
66
  self.captions = []
67
  self.image_paths = []
@@ -69,7 +69,7 @@ class ImageTextDataset(VisionDataset):
69
  for example in examples:
70
  captions = example["captions"][:captions_per_image]
71
  self.captions.extend(captions)
72
- self.image_paths.extend([example["image_path"]] * len(captions))
73
 
74
  def _load_image(self, idx: int):
75
  path = self.image_paths[idx]
 
53
  self,
54
  root: str,
55
  file_path: str,
56
+ captions_per_image=5,
57
  transform: Optional[Callable] = None,
58
  target_transform: Optional[Callable] = None,
59
  transforms: Optional[Callable] = None,
 
61
  super().__init__(root, transforms, transform, target_transform)
62
 
63
  with open(file_path, "r") as f:
64
+ examples = json.load(f)
65
 
66
  self.captions = []
67
  self.image_paths = []
 
69
  for example in examples:
70
  captions = example["captions"][:captions_per_image]
71
  self.captions.extend(captions)
72
+ self.image_paths.extend([example["file_path"]] * len(captions))
73
 
74
  def _load_image(self, idx: int):
75
  path = self.image_paths[idx]
down_wit.py DELETED
@@ -1,79 +0,0 @@
1
-
2
- import csv
3
- import glob
4
- from typing import Text, List
5
- import urllib.request
6
- import requests
7
- from multiprocessing import Pool
8
- import socket
9
- timeout = 10
10
- socket.setdefaulttimeout(timeout)
11
-
12
-
13
- DATA_PATH='/home/shared/dataset/wit'
14
- # DATA_PATH='../data/wit'
15
-
16
-
17
- def load_file(path):
18
- """
19
- load csv
20
- """
21
- with open(path) as f:
22
- reader = csv.reader(f, delimiter='\t', quotechar='"')
23
- data = list(reader)
24
- return data
25
-
26
-
27
- def extract_ko(data):
28
- """
29
- Extract lang=ko data samples
30
- """
31
- trainset = []
32
- for samp in data[1:]:
33
- if samp[0] != 'ko':
34
- continue
35
- trainset.append(samp)
36
- return trainset
37
-
38
-
39
- def rewrite_wit(data_paths):
40
- """
41
- we need only korean set. extract only korean set.
42
- https://drive.google.com/file/d/1y_DxYrmUF4vw3m7UOlVsHSkcO_v0XuLv/view?usp=sharing
43
-
44
- """
45
- samples = []
46
- for path in data_paths:
47
- data = load_file(path)
48
- samples += extract_ko(data)
49
- return [[i, *samp] for i, samp in enumerate(samples)]
50
-
51
-
52
- def req_imgs(url_info):
53
- """ download imgs """
54
- # request.get 요청
55
- response = requests.get(url_info[1], headers={'User-agent': 'your bot 0.1'})
56
- with open(f'{DATA_PATH}/img/{url_info[0]}.jpg', 'wb') as f:
57
- f.write(response.content)
58
- # print(f"{url_info[0]} is done.")
59
-
60
-
61
- def down_imgs(urls):
62
- with Pool(2) as p:
63
- p.map(req_imgs, urls)
64
-
65
-
66
- if __name__ == '__main__':
67
- # path_list = glob.glob('/home/shared/dataset/wit')
68
- path_list = glob.glob(f'{DATA_PATH}/info/*')
69
-
70
- samples = rewrite_wit(path_list)
71
-
72
- with open(f'{DATA_PATH}/wit_ko.csv', 'w') as f:
73
- writer = csv.writer(f, delimiter='\t', quotechar='"')
74
- writer.writerows(samples)
75
-
76
- url_list = [[samp[0], samp[3]] for samp in samples]
77
- down_imgs(url_list)
78
-
79
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_hybrid_clip.py CHANGED
File without changes
models/coco_only/config.json ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "HybridCLIP"
4
+ ],
5
+ "initializer_factor": 1.0,
6
+ "model_type": "hybrid-clip",
7
+ "projection_dim": 512,
8
+ "seed": 42,
9
+ "text_config": {
10
+ "_name_or_path": "",
11
+ "add_cross_attention": false,
12
+ "architectures": [
13
+ "RobertaForMaskedLM"
14
+ ],
15
+ "attention_probs_dropout_prob": 0.1,
16
+ "bad_words_ids": null,
17
+ "bos_token_id": 0,
18
+ "chunk_size_feed_forward": 0,
19
+ "decoder_start_token_id": null,
20
+ "diversity_penalty": 0.0,
21
+ "do_sample": false,
22
+ "early_stopping": false,
23
+ "encoder_no_repeat_ngram_size": 0,
24
+ "eos_token_id": 2,
25
+ "finetuning_task": null,
26
+ "forced_bos_token_id": null,
27
+ "forced_eos_token_id": null,
28
+ "gradient_checkpointing": false,
29
+ "hidden_act": "gelu",
30
+ "hidden_dropout_prob": 0.1,
31
+ "hidden_size": 1024,
32
+ "id2label": {
33
+ "0": "LABEL_0",
34
+ "1": "LABEL_1"
35
+ },
36
+ "initializer_range": 0.02,
37
+ "intermediate_size": 4096,
38
+ "is_decoder": false,
39
+ "is_encoder_decoder": false,
40
+ "label2id": {
41
+ "LABEL_0": 0,
42
+ "LABEL_1": 1
43
+ },
44
+ "layer_norm_eps": 1e-05,
45
+ "length_penalty": 1.0,
46
+ "max_length": 20,
47
+ "max_position_embeddings": 512,
48
+ "min_length": 0,
49
+ "model_type": "roberta",
50
+ "no_repeat_ngram_size": 0,
51
+ "num_attention_heads": 16,
52
+ "num_beam_groups": 1,
53
+ "num_beams": 1,
54
+ "num_hidden_layers": 24,
55
+ "num_return_sequences": 1,
56
+ "output_attentions": false,
57
+ "output_hidden_states": false,
58
+ "output_scores": false,
59
+ "pad_token_id": 1,
60
+ "position_embedding_type": "absolute",
61
+ "prefix": null,
62
+ "problem_type": null,
63
+ "pruned_heads": {},
64
+ "remove_invalid_values": false,
65
+ "repetition_penalty": 1.0,
66
+ "return_dict": true,
67
+ "return_dict_in_generate": false,
68
+ "sep_token_id": null,
69
+ "task_specific_params": null,
70
+ "temperature": 1.0,
71
+ "tie_encoder_decoder": false,
72
+ "tie_word_embeddings": true,
73
+ "tokenizer_class": "BertTokenizer",
74
+ "top_k": 50,
75
+ "top_p": 1.0,
76
+ "torch_dtype": null,
77
+ "torchscript": false,
78
+ "transformers_version": "4.9.0.dev0",
79
+ "type_vocab_size": 1,
80
+ "use_bfloat16": false,
81
+ "use_cache": true,
82
+ "vocab_size": 32000
83
+ },
84
+ "transformers_version": null,
85
+ "vision_config": {
86
+ "_name_or_path": "",
87
+ "add_cross_attention": false,
88
+ "architectures": null,
89
+ "attention_dropout": 0.0,
90
+ "bad_words_ids": null,
91
+ "bos_token_id": null,
92
+ "chunk_size_feed_forward": 0,
93
+ "decoder_start_token_id": null,
94
+ "diversity_penalty": 0.0,
95
+ "do_sample": false,
96
+ "dropout": 0.0,
97
+ "early_stopping": false,
98
+ "encoder_no_repeat_ngram_size": 0,
99
+ "eos_token_id": null,
100
+ "finetuning_task": null,
101
+ "forced_bos_token_id": null,
102
+ "forced_eos_token_id": null,
103
+ "gradient_checkpointing": false,
104
+ "hidden_act": "quick_gelu",
105
+ "hidden_size": 768,
106
+ "id2label": {
107
+ "0": "LABEL_0",
108
+ "1": "LABEL_1"
109
+ },
110
+ "image_size": 224,
111
+ "initializer_factor": 1.0,
112
+ "initializer_range": 0.02,
113
+ "intermediate_size": 3072,
114
+ "is_decoder": false,
115
+ "is_encoder_decoder": false,
116
+ "label2id": {
117
+ "LABEL_0": 0,
118
+ "LABEL_1": 1
119
+ },
120
+ "layer_norm_eps": 1e-05,
121
+ "length_penalty": 1.0,
122
+ "max_length": 20,
123
+ "min_length": 0,
124
+ "model_type": "clip_vision_model",
125
+ "no_repeat_ngram_size": 0,
126
+ "num_attention_heads": 12,
127
+ "num_beam_groups": 1,
128
+ "num_beams": 1,
129
+ "num_hidden_layers": 12,
130
+ "num_return_sequences": 1,
131
+ "output_attentions": false,
132
+ "output_hidden_states": false,
133
+ "output_scores": false,
134
+ "pad_token_id": null,
135
+ "patch_size": 32,
136
+ "prefix": null,
137
+ "problem_type": null,
138
+ "pruned_heads": {},
139
+ "remove_invalid_values": false,
140
+ "repetition_penalty": 1.0,
141
+ "return_dict": true,
142
+ "return_dict_in_generate": false,
143
+ "sep_token_id": null,
144
+ "task_specific_params": null,
145
+ "temperature": 1.0,
146
+ "tie_encoder_decoder": false,
147
+ "tie_word_embeddings": true,
148
+ "tokenizer_class": null,
149
+ "top_k": 50,
150
+ "top_p": 1.0,
151
+ "torch_dtype": null,
152
+ "torchscript": false,
153
+ "transformers_version": "4.9.0.dev0",
154
+ "use_bfloat16": false
155
+ }
156
+ }
models/coco_only/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1397edcc4c8f8e3c72fcb4a3cfdc742aa6ff727206f601e100e4df7398b2001
3
+ size 1700132358
requirements.txt CHANGED
File without changes
run_hybrid_clip.py CHANGED
@@ -31,24 +31,30 @@ from dataclasses import dataclass, field
31
  from pathlib import Path
32
  from typing import Callable, Optional
33
 
34
- import torch
35
- from torchvision.datasets import VisionDataset
36
- from torchvision.io import ImageReadMode, read_image
37
- from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
38
- from torchvision.transforms.functional import InterpolationMode
39
- from tqdm import tqdm
40
-
41
  import jax
42
  import jax.numpy as jnp
43
  import optax
 
44
  import transformers
45
  from flax import jax_utils
46
  from flax.jax_utils import unreplicate
47
  from flax.training import train_state
48
  from flax.training.common_utils import get_metrics, shard, shard_prng_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  from modeling_hybrid_clip import FlaxHybridCLIP
50
- from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
51
-
52
 
53
  logger = logging.getLogger(__name__)
54
 
@@ -59,7 +65,9 @@ if has_tensorboard:
59
  from flax.metrics.tensorboard import SummaryWriter
60
  except ImportError as ie:
61
  has_tensorboard = False
62
- print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
 
 
63
 
64
  else:
65
  print(
@@ -88,20 +96,33 @@ class ModelArguments:
88
  )
89
  from_pt: bool = field(
90
  default=True,
91
- metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
 
 
92
  )
93
  config_name: Optional[str] = field(
94
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
 
 
 
95
  )
96
  tokenizer_name: Optional[str] = field(
97
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
 
 
 
98
  )
99
  cache_dir: Optional[str] = field(
100
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
 
 
 
101
  )
102
  use_fast_tokenizer: bool = field(
103
  default=True,
104
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
 
 
105
  )
106
  dtype: Optional[str] = field(
107
  default="float32",
@@ -117,9 +138,12 @@ class DataTrainingArguments:
117
  Arguments pertaining to what data we are going to input our model for training and eval.
118
  """
119
 
120
- data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
 
 
121
  train_file: Optional[str] = field(
122
- default=None, metadata={"help": "The input training data file (a jsonlines file)."}
 
123
  )
124
  validation_file: Optional[str] = field(
125
  default=None,
@@ -147,10 +171,12 @@ class DataTrainingArguments:
147
  },
148
  )
149
  overwrite_cache: bool = field(
150
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
 
151
  )
152
  overwrite_cache: bool = field(
153
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
 
154
  )
155
  preprocessing_num_workers: Optional[int] = field(
156
  default=None,
@@ -159,7 +185,9 @@ class DataTrainingArguments:
159
 
160
  def __post_init__(self):
161
  if self.train_file is None and self.validation_file is None:
162
- raise ValueError("Need either a dataset name or a training/validation file.")
 
 
163
  else:
164
  if self.train_file is not None:
165
  extension = self.train_file.split(".")[-1]
@@ -169,87 +197,13 @@ class DataTrainingArguments:
169
  assert extension == "json", "`validation_file` should be a json file."
170
 
171
 
172
- # We use torchvision for faster image pre-processing.
173
- # We need to ensure faster processing speed as it can become a bottleneck on TPU
174
- class Transform(torch.nn.Module):
175
- def __init__(self, image_size):
176
- super().__init__()
177
- self.transforms = torch.nn.Sequential(
178
- Resize([image_size], interpolation=InterpolationMode.BICUBIC),
179
- CenterCrop(image_size),
180
- ConvertImageDtype(torch.float),
181
- Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
182
- )
183
-
184
- def forward(self, x: torch.Tensor) -> torch.Tensor:
185
- with torch.no_grad():
186
- x = self.transforms(x)
187
- return x
188
-
189
-
190
- class ImageTextDataset(VisionDataset):
191
- """
192
- Dtaset for loading image-text data for tasks like CLIP training, Image Captioning.
193
- Args:
194
- root: (string): The root path where the dataset is stored
195
- file_path: (string): Path to the file containing the image_paths and associated captions.
196
- The expected format is jsonlines where each line is a json object containing to keys.
197
- `image_path`: The path to the image.
198
- `captions`: An `array` of captions.
199
- transform (callable, optional): A function/transform that takes in an PIL image
200
- and returns a transformed version. E.g, ``transforms.ToTensor``
201
- target_transform (callable, optional): A function/transform that takes in the
202
- target and transforms it.
203
- transforms (callable, optional): A function/transform that takes input sample and its target as entry
204
- and returns a transformed version.
205
- """
206
-
207
- def __init__(
208
- self,
209
- root: str,
210
- file_path: str,
211
- captions_per_image=2,
212
- transform: Optional[Callable] = None,
213
- target_transform: Optional[Callable] = None,
214
- transforms: Optional[Callable] = None,
215
- ):
216
- super().__init__(root, transforms, transform, target_transform)
217
-
218
- with open(file_path, "r") as f:
219
- examples = [json.loads(line) for line in f.readlines()]
220
-
221
- self.captions = []
222
- self.image_paths = []
223
-
224
- for example in examples:
225
- self.captions.extend(example["captions"][:captions_per_image])
226
- self.image_paths.extend([example["image_path"]] * captions_per_image)
227
-
228
- def _load_image(self, idx: int):
229
- path = self.image_paths[idx]
230
- return read_image(path, mode=ImageReadMode.RGB)
231
-
232
- def _load_target(self, idx):
233
- return self.captions[idx]
234
-
235
- def __getitem__(self, index: int):
236
- image = self._load_image(index)
237
- target = self._load_target(index)
238
-
239
- if self.transforms is not None:
240
- image, target = self.transforms(image, target)
241
-
242
- return image, target
243
-
244
- def __len__(self) -> int:
245
- return len(self.captions)
246
-
247
-
248
  class TrainState(train_state.TrainState):
249
  dropout_rng: jnp.ndarray
250
 
251
  def replicate(self):
252
- return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
 
 
253
 
254
 
255
  def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
@@ -266,25 +220,39 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
266
 
267
 
268
  def create_learning_rate_fn(
269
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
 
 
 
 
270
  ) -> Callable[[int], jnp.array]:
271
  """Returns a linear warmup, linear_decay learning rate function."""
272
  steps_per_epoch = train_ds_size // train_batch_size
273
  num_train_steps = steps_per_epoch * num_train_epochs
274
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
 
 
275
  decay_fn = optax.linear_schedule(
276
- init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
 
 
 
 
 
277
  )
278
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
279
  return schedule_fn
280
 
281
 
282
  def main():
283
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
 
 
284
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
285
  # If we pass only one argument to the script and it's the path to a json file,
286
  # let's parse it to get our arguments.
287
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
 
 
288
  else:
289
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
290
 
@@ -317,11 +285,15 @@ def main():
317
 
318
  if model_args.tokenizer_name:
319
  tokenizer = AutoTokenizer.from_pretrained(
320
- model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
321
  )
322
  elif model_args.text_model_name_or_path:
323
  tokenizer = AutoTokenizer.from_pretrained(
324
- model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
325
  )
326
  else:
327
  raise ValueError(
@@ -349,29 +321,40 @@ def main():
349
  train_dataset = ImageTextDataset(
350
  data_args.data_dir,
351
  data_args.train_file,
352
- captions_per_image=2,
353
  transform=preprocess,
354
  )
355
 
356
  eval_dataset = ImageTextDataset(
357
  data_args.data_dir,
358
  data_args.validation_file,
359
- captions_per_image=1,
360
  transform=preprocess,
361
  )
362
 
363
  # Store some constant
364
  num_epochs = int(training_args.num_train_epochs)
365
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
 
 
366
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
367
  steps_per_epoch = len(train_dataset) // train_batch_size
368
  total_train_steps = steps_per_epoch * num_epochs
369
 
370
  # Use collate function to tokenizer the text and convert the processed images to numpy
371
  def collate_fn(examples):
372
- pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
 
 
 
 
373
  captions = [example[1] for example in examples]
374
- inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
 
 
 
 
 
375
 
376
  batch = {
377
  "pixel_values": pixel_values,
@@ -404,7 +387,9 @@ def main():
404
 
405
  # Enable tensorboard only on the master node
406
  if has_tensorboard and jax.process_index() == 0:
407
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
 
 
408
 
409
  # Initialize our training
410
  rng = jax.random.PRNGKey(training_args.seed)
@@ -429,7 +414,9 @@ def main():
429
  )
430
 
431
  # Setup train state
432
- state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
 
 
433
 
434
  def cross_entropy(logits, axis):
435
  logprobs = jax.nn.log_softmax(logits, axis=axis)
@@ -438,7 +425,9 @@ def main():
438
  return ce
439
 
440
  def clip_loss(similarity):
441
- loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
 
 
442
  return loss
443
 
444
  # Define gradient update step fn
@@ -446,7 +435,9 @@ def main():
446
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
447
 
448
  def compute_loss(params):
449
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
 
450
  loss = clip_loss(logits)
451
  return loss
452
 
@@ -456,7 +447,10 @@ def main():
456
 
457
  new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
458
 
459
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
 
 
 
460
  metrics = jax.lax.pmean(metrics, axis_name="batch")
461
 
462
  return new_state, metrics
@@ -481,8 +475,12 @@ def main():
481
  logger.info("***** Running training *****")
482
  logger.info(f" Num examples = {len(train_dataset)}")
483
  logger.info(f" Num Epochs = {num_epochs}")
484
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
485
- logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
 
 
 
 
486
  logger.info(f" Total optimization steps = {total_train_steps}")
487
 
488
  train_time = 0
@@ -499,7 +497,9 @@ def main():
499
  train_metrics = []
500
 
501
  steps_per_epoch = len(train_dataset) // train_batch_size
502
- train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
 
 
503
  # train
504
  for batch in train_loader:
505
  batch = shard(batch)
@@ -520,7 +520,9 @@ def main():
520
  # ======================== Evaluating ==============================
521
  eval_metrics = []
522
  eval_steps = len(eval_dataset) // eval_batch_size
523
- eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
 
 
524
  for batch in eval_loader:
525
  # Model forward
526
  batch = shard(batch)
@@ -536,14 +538,18 @@ def main():
536
 
537
  # Print metrics and update progress bar
538
  eval_step_progress_bar.close()
539
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
 
540
  epochs.write(desc)
541
  epochs.desc = desc
542
 
543
  # Save metrics
544
  if has_tensorboard and jax.process_index() == 0:
545
  cur_step = epoch * (len(train_dataset) // train_batch_size)
546
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
 
 
547
 
548
  # save checkpoint after each epoch and push checkpoint to the hub
549
  if jax.process_index() == 0:
@@ -557,4 +563,4 @@ def main():
557
 
558
 
559
  if __name__ == "__main__":
560
- main()
 
31
  from pathlib import Path
32
  from typing import Callable, Optional
33
 
 
 
 
 
 
 
 
34
  import jax
35
  import jax.numpy as jnp
36
  import optax
37
+ import torch
38
  import transformers
39
  from flax import jax_utils
40
  from flax.jax_utils import unreplicate
41
  from flax.training import train_state
42
  from flax.training.common_utils import get_metrics, shard, shard_prng_key
43
+ from torchvision.datasets import VisionDataset
44
+ from torchvision.io import ImageReadMode, read_image
45
+ from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
46
+ from torchvision.transforms.functional import InterpolationMode
47
+ from tqdm import tqdm
48
+ from transformers import (
49
+ AutoTokenizer,
50
+ HfArgumentParser,
51
+ TrainingArguments,
52
+ is_tensorboard_available,
53
+ set_seed,
54
+ )
55
+
56
+ from dataloader import ImageTextDataset, Transform
57
  from modeling_hybrid_clip import FlaxHybridCLIP
 
 
58
 
59
  logger = logging.getLogger(__name__)
60
 
 
65
  from flax.metrics.tensorboard import SummaryWriter
66
  except ImportError as ie:
67
  has_tensorboard = False
68
+ print(
69
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
70
+ )
71
 
72
  else:
73
  print(
 
96
  )
97
  from_pt: bool = field(
98
  default=True,
99
+ metadata={
100
+ "help": "whether to load the text and vision model using PyTorch checkpoints."
101
+ },
102
  )
103
  config_name: Optional[str] = field(
104
+ default=None,
105
+ metadata={
106
+ "help": "Pretrained config name or path if not the same as model_name"
107
+ },
108
  )
109
  tokenizer_name: Optional[str] = field(
110
+ default=None,
111
+ metadata={
112
+ "help": "Pretrained tokenizer name or path if not the same as model_name"
113
+ },
114
  )
115
  cache_dir: Optional[str] = field(
116
+ default=None,
117
+ metadata={
118
+ "help": "Where do you want to store the pretrained models downloaded from s3"
119
+ },
120
  )
121
  use_fast_tokenizer: bool = field(
122
  default=True,
123
+ metadata={
124
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
125
+ },
126
  )
127
  dtype: Optional[str] = field(
128
  default="float32",
 
138
  Arguments pertaining to what data we are going to input our model for training and eval.
139
  """
140
 
141
+ data_dir: Optional[str] = field(
142
+ default=None, metadata={"help": "The data directory containing input files."}
143
+ )
144
  train_file: Optional[str] = field(
145
+ default=None,
146
+ metadata={"help": "The input training data file (a jsonlines file)."},
147
  )
148
  validation_file: Optional[str] = field(
149
  default=None,
 
171
  },
172
  )
173
  overwrite_cache: bool = field(
174
+ default=False,
175
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
176
  )
177
  overwrite_cache: bool = field(
178
+ default=False,
179
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
180
  )
181
  preprocessing_num_workers: Optional[int] = field(
182
  default=None,
 
185
 
186
  def __post_init__(self):
187
  if self.train_file is None and self.validation_file is None:
188
+ raise ValueError(
189
+ "Need either a dataset name or a training/validation file."
190
+ )
191
  else:
192
  if self.train_file is not None:
193
  extension = self.train_file.split(".")[-1]
 
197
  assert extension == "json", "`validation_file` should be a json file."
198
 
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  class TrainState(train_state.TrainState):
201
  dropout_rng: jnp.ndarray
202
 
203
  def replicate(self):
204
+ return jax_utils.replicate(self).replace(
205
+ dropout_rng=shard_prng_key(self.dropout_rng)
206
+ )
207
 
208
 
209
  def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
220
 
221
 
222
  def create_learning_rate_fn(
223
+ train_ds_size: int,
224
+ train_batch_size: int,
225
+ num_train_epochs: int,
226
+ num_warmup_steps: int,
227
+ learning_rate: float,
228
  ) -> Callable[[int], jnp.array]:
229
  """Returns a linear warmup, linear_decay learning rate function."""
230
  steps_per_epoch = train_ds_size // train_batch_size
231
  num_train_steps = steps_per_epoch * num_train_epochs
232
+ warmup_fn = optax.linear_schedule(
233
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
234
+ )
235
  decay_fn = optax.linear_schedule(
236
+ init_value=learning_rate,
237
+ end_value=0,
238
+ transition_steps=num_train_steps - num_warmup_steps,
239
+ )
240
+ schedule_fn = optax.join_schedules(
241
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
242
  )
 
243
  return schedule_fn
244
 
245
 
246
  def main():
247
+ parser = HfArgumentParser(
248
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
249
+ )
250
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
251
  # If we pass only one argument to the script and it's the path to a json file,
252
  # let's parse it to get our arguments.
253
+ model_args, data_args, training_args = parser.parse_json_file(
254
+ json_file=os.path.abspath(sys.argv[1])
255
+ )
256
  else:
257
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
258
 
 
285
 
286
  if model_args.tokenizer_name:
287
  tokenizer = AutoTokenizer.from_pretrained(
288
+ model_args.tokenizer_name,
289
+ cache_dir=model_args.cache_dir,
290
+ use_fast=model_args.use_fast_tokenizer,
291
  )
292
  elif model_args.text_model_name_or_path:
293
  tokenizer = AutoTokenizer.from_pretrained(
294
+ model_args.text_model_name_or_path,
295
+ cache_dir=model_args.cache_dir,
296
+ use_fast=model_args.use_fast_tokenizer,
297
  )
298
  else:
299
  raise ValueError(
 
321
  train_dataset = ImageTextDataset(
322
  data_args.data_dir,
323
  data_args.train_file,
324
+ captions_per_image=5,
325
  transform=preprocess,
326
  )
327
 
328
  eval_dataset = ImageTextDataset(
329
  data_args.data_dir,
330
  data_args.validation_file,
331
+ captions_per_image=5,
332
  transform=preprocess,
333
  )
334
 
335
  # Store some constant
336
  num_epochs = int(training_args.num_train_epochs)
337
+ train_batch_size = (
338
+ int(training_args.per_device_train_batch_size) * jax.device_count()
339
+ )
340
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
341
  steps_per_epoch = len(train_dataset) // train_batch_size
342
  total_train_steps = steps_per_epoch * num_epochs
343
 
344
  # Use collate function to tokenizer the text and convert the processed images to numpy
345
  def collate_fn(examples):
346
+ pixel_values = (
347
+ torch.stack([example[0] for example in examples])
348
+ .permute(0, 2, 3, 1)
349
+ .numpy()
350
+ )
351
  captions = [example[1] for example in examples]
352
+ inputs = tokenizer(
353
+ captions,
354
+ max_length=data_args.max_seq_length,
355
+ padding="max_length",
356
+ return_tensors="np",
357
+ )
358
 
359
  batch = {
360
  "pixel_values": pixel_values,
 
387
 
388
  # Enable tensorboard only on the master node
389
  if has_tensorboard and jax.process_index() == 0:
390
+ summary_writer = SummaryWriter(
391
+ log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
392
+ )
393
 
394
  # Initialize our training
395
  rng = jax.random.PRNGKey(training_args.seed)
 
414
  )
415
 
416
  # Setup train state
417
+ state = TrainState.create(
418
+ apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng
419
+ )
420
 
421
  def cross_entropy(logits, axis):
422
  logprobs = jax.nn.log_softmax(logits, axis=axis)
 
425
  return ce
426
 
427
  def clip_loss(similarity):
428
+ loss = (
429
+ cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
430
+ ) / 2
431
  return loss
432
 
433
  # Define gradient update step fn
 
435
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
436
 
437
  def compute_loss(params):
438
+ logits = state.apply_fn(
439
+ **batch, params=params, dropout_rng=dropout_rng, train=True
440
+ )[0]
441
  loss = clip_loss(logits)
442
  return loss
443
 
 
447
 
448
  new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
449
 
450
+ metrics = {
451
+ "loss": loss,
452
+ "learning_rate": linear_decay_lr_schedule_fn(state.step),
453
+ }
454
  metrics = jax.lax.pmean(metrics, axis_name="batch")
455
 
456
  return new_state, metrics
 
475
  logger.info("***** Running training *****")
476
  logger.info(f" Num examples = {len(train_dataset)}")
477
  logger.info(f" Num Epochs = {num_epochs}")
478
+ logger.info(
479
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
480
+ )
481
+ logger.info(
482
+ f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
483
+ )
484
  logger.info(f" Total optimization steps = {total_train_steps}")
485
 
486
  train_time = 0
 
497
  train_metrics = []
498
 
499
  steps_per_epoch = len(train_dataset) // train_batch_size
500
+ train_step_progress_bar = tqdm(
501
+ total=steps_per_epoch, desc="Training...", position=1, leave=False
502
+ )
503
  # train
504
  for batch in train_loader:
505
  batch = shard(batch)
 
520
  # ======================== Evaluating ==============================
521
  eval_metrics = []
522
  eval_steps = len(eval_dataset) // eval_batch_size
523
+ eval_step_progress_bar = tqdm(
524
+ total=eval_steps, desc="Evaluating...", position=2, leave=False
525
+ )
526
  for batch in eval_loader:
527
  # Model forward
528
  batch = shard(batch)
 
538
 
539
  # Print metrics and update progress bar
540
  eval_step_progress_bar.close()
541
+ desc = (
542
+ f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
543
+ )
544
  epochs.write(desc)
545
  epochs.desc = desc
546
 
547
  # Save metrics
548
  if has_tensorboard and jax.process_index() == 0:
549
  cur_step = epoch * (len(train_dataset) // train_batch_size)
550
+ write_metric(
551
+ summary_writer, train_metrics, eval_metrics, train_time, cur_step
552
+ )
553
 
554
  # save checkpoint after each epoch and push checkpoint to the hub
555
  if jax.process_index() == 0:
 
563
 
564
 
565
  if __name__ == "__main__":
566
+ main()
run_hybrid_clip_en.py DELETED
@@ -1,570 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2021 The HuggingFace Team All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- Training a CLIP like dual encoder models using text and vision encoders in the library.
18
- The script can be used to train CLIP like models for languages other than english by using
19
- a text encoder pre-trained in the desired language. Currently this script support the following vision
20
- and text models:
21
- Vision models: ViT(https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
22
- Text models: BERT, ROBERTa (https://huggingface.co/models?filter=masked-lm)
23
- """
24
-
25
- import json
26
- import logging
27
- import os
28
- import sys
29
- import time
30
- from dataclasses import dataclass, field
31
- from pathlib import Path
32
- from typing import Callable, Optional
33
-
34
- import jax
35
- import jax.numpy as jnp
36
- import optax
37
- import torch
38
- import transformers
39
- from flax import jax_utils
40
- from flax.jax_utils import unreplicate
41
- from flax.training import train_state
42
- from flax.training.common_utils import get_metrics, shard, shard_prng_key
43
- from pororo import Pororo
44
- from torchvision.datasets import VisionDataset
45
- from torchvision.io import ImageReadMode, read_image
46
- from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
47
- from torchvision.transforms.functional import InterpolationMode
48
- from tqdm import tqdm
49
- from transformers import (
50
- AutoTokenizer,
51
- HfArgumentParser,
52
- TrainingArguments,
53
- is_tensorboard_available,
54
- set_seed,
55
- )
56
-
57
- from dataloader import ImageTextDataset, Transform
58
- from modeling_hybrid_clip import FlaxHybridCLIP
59
-
60
- logger = logging.getLogger(__name__)
61
-
62
- # Cache the result
63
- has_tensorboard = is_tensorboard_available()
64
- if has_tensorboard:
65
- try:
66
- from flax.metrics.tensorboard import SummaryWriter
67
- except ImportError as ie:
68
- has_tensorboard = False
69
- print(
70
- f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
71
- )
72
-
73
- else:
74
- print(
75
- "Unable to display metrics through TensorBoard because the package is not installed: "
76
- "Please run pip install tensorboard to enable."
77
- )
78
-
79
-
80
- @dataclass
81
- class ModelArguments:
82
- """
83
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
84
- """
85
-
86
- text_model_name_or_path: str = field(
87
- metadata={
88
- "help": "The text model checkpoint for weights initialization."
89
- "Don't set if you want to train a model from scratch."
90
- },
91
- )
92
- vision_model_name_or_path: str = field(
93
- metadata={
94
- "help": "The vision model checkpoint for weights initialization."
95
- "Don't set if you want to train a model from scratch."
96
- },
97
- )
98
- from_pt: bool = field(
99
- default=True,
100
- metadata={
101
- "help": "whether to load the text and vision model using PyTorch checkpoints."
102
- },
103
- )
104
- config_name: Optional[str] = field(
105
- default=None,
106
- metadata={
107
- "help": "Pretrained config name or path if not the same as model_name"
108
- },
109
- )
110
- tokenizer_name: Optional[str] = field(
111
- default=None,
112
- metadata={
113
- "help": "Pretrained tokenizer name or path if not the same as model_name"
114
- },
115
- )
116
- cache_dir: Optional[str] = field(
117
- default=None,
118
- metadata={
119
- "help": "Where do you want to store the pretrained models downloaded from s3"
120
- },
121
- )
122
- use_fast_tokenizer: bool = field(
123
- default=True,
124
- metadata={
125
- "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
126
- },
127
- )
128
- dtype: Optional[str] = field(
129
- default="float32",
130
- metadata={
131
- "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
132
- },
133
- )
134
-
135
-
136
- @dataclass
137
- class DataTrainingArguments:
138
- """
139
- Arguments pertaining to what data we are going to input our model for training and eval.
140
- """
141
-
142
- data_dir: Optional[str] = field(
143
- default=None, metadata={"help": "The data directory containing input files."}
144
- )
145
- train_file: Optional[str] = field(
146
- default=None,
147
- metadata={"help": "The input training data file (a jsonlines file)."},
148
- )
149
- validation_file: Optional[str] = field(
150
- default=None,
151
- metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
152
- )
153
- max_seq_length: Optional[int] = field(
154
- default=72,
155
- metadata={
156
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
157
- "than this will be truncated, sequences shorter will be padded."
158
- },
159
- )
160
- max_train_samples: Optional[int] = field(
161
- default=None,
162
- metadata={
163
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
164
- "value if set."
165
- },
166
- )
167
- max_eval_samples: Optional[int] = field(
168
- default=None,
169
- metadata={
170
- "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
171
- "value if set."
172
- },
173
- )
174
- overwrite_cache: bool = field(
175
- default=False,
176
- metadata={"help": "Overwrite the cached training and evaluation sets"},
177
- )
178
- overwrite_cache: bool = field(
179
- default=False,
180
- metadata={"help": "Overwrite the cached training and evaluation sets"},
181
- )
182
- preprocessing_num_workers: Optional[int] = field(
183
- default=None,
184
- metadata={"help": "The number of processes to use for the preprocessing."},
185
- )
186
-
187
- def __post_init__(self):
188
- if self.train_file is None and self.validation_file is None:
189
- raise ValueError(
190
- "Need either a dataset name or a training/validation file."
191
- )
192
- else:
193
- if self.train_file is not None:
194
- extension = self.train_file.split(".")[-1]
195
- assert extension == "json", "`train_file` should be a json file."
196
- if self.validation_file is not None:
197
- extension = self.validation_file.split(".")[-1]
198
- assert extension == "json", "`validation_file` should be a json file."
199
-
200
-
201
- class TrainState(train_state.TrainState):
202
- dropout_rng: jnp.ndarray
203
-
204
- def replicate(self):
205
- return jax_utils.replicate(self).replace(
206
- dropout_rng=shard_prng_key(self.dropout_rng)
207
- )
208
-
209
-
210
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
211
- summary_writer.scalar("train_time", train_time, step)
212
-
213
- train_metrics = get_metrics(train_metrics)
214
- for key, vals in train_metrics.items():
215
- tag = f"train_{key}"
216
- for i, val in enumerate(vals):
217
- summary_writer.scalar(tag, val, step - len(vals) + i + 1)
218
-
219
- for metric_name, value in eval_metrics.items():
220
- summary_writer.scalar(f"eval_{metric_name}", value, step)
221
-
222
-
223
- def create_learning_rate_fn(
224
- train_ds_size: int,
225
- train_batch_size: int,
226
- num_train_epochs: int,
227
- num_warmup_steps: int,
228
- learning_rate: float,
229
- ) -> Callable[[int], jnp.array]:
230
- """Returns a linear warmup, linear_decay learning rate function."""
231
- steps_per_epoch = train_ds_size // train_batch_size
232
- num_train_steps = steps_per_epoch * num_train_epochs
233
- warmup_fn = optax.linear_schedule(
234
- init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
235
- )
236
- decay_fn = optax.linear_schedule(
237
- init_value=learning_rate,
238
- end_value=0,
239
- transition_steps=num_train_steps - num_warmup_steps,
240
- )
241
- schedule_fn = optax.join_schedules(
242
- schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
243
- )
244
- return schedule_fn
245
-
246
-
247
- def main():
248
- parser = HfArgumentParser(
249
- (ModelArguments, DataTrainingArguments, TrainingArguments)
250
- )
251
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
252
- # If we pass only one argument to the script and it's the path to a json file,
253
- # let's parse it to get our arguments.
254
- model_args, data_args, training_args = parser.parse_json_file(
255
- json_file=os.path.abspath(sys.argv[1])
256
- )
257
- else:
258
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
259
-
260
- if (
261
- os.path.exists(training_args.output_dir)
262
- and os.listdir(training_args.output_dir)
263
- and training_args.do_train
264
- and not training_args.overwrite_output_dir
265
- ):
266
- raise ValueError(
267
- f"Output directory ({training_args.output_dir}) already exists and is not empty."
268
- "Use --overwrite_output_dir to overcome."
269
- )
270
-
271
- # Make one log on every process with the configuration for debugging.
272
- logging.basicConfig(
273
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
274
- datefmt="%m/%d/%Y %H:%M:%S",
275
- level=logging.INFO,
276
- )
277
- # Setup logging, we only want one process per machine to log things on the screen.
278
- logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
279
- if jax.process_index() == 0:
280
- transformers.utils.logging.set_verbosity_info()
281
- else:
282
- transformers.utils.logging.set_verbosity_error()
283
-
284
- # Set the verbosity to info of the Transformers logger (on main process only):
285
- logger.info(f"Training/evaluation parameters {training_args}")
286
-
287
- if model_args.tokenizer_name:
288
- tokenizer = AutoTokenizer.from_pretrained(
289
- model_args.tokenizer_name,
290
- cache_dir=model_args.cache_dir,
291
- use_fast=model_args.use_fast_tokenizer,
292
- )
293
- elif model_args.text_model_name_or_path:
294
- tokenizer = AutoTokenizer.from_pretrained(
295
- model_args.text_model_name_or_path,
296
- cache_dir=model_args.cache_dir,
297
- use_fast=model_args.use_fast_tokenizer,
298
- )
299
- else:
300
- raise ValueError(
301
- "You are instantiating a new tokenizer from scratch. This is not supported by this script."
302
- "You can do it from another script, save it, and load it from here, using --tokenizer_name."
303
- )
304
-
305
- model = FlaxHybridCLIP.from_text_vision_pretrained(
306
- model_args.text_model_name_or_path,
307
- model_args.vision_model_name_or_path,
308
- seed=training_args.seed,
309
- dtype=getattr(jnp, model_args.dtype),
310
- text_from_pt=model_args.from_pt,
311
- vision_from_pt=model_args.from_pt,
312
- )
313
- config = model.config
314
- # set seed for torch dataloaders
315
- set_seed(training_args.seed)
316
-
317
- # Initialize torchvision transforms and jit them for faster processing
318
- preprocess = Transform(config.vision_config.image_size)
319
- preprocess = torch.jit.script(preprocess)
320
-
321
- # Initialize the image-text dataset
322
- train_dataset = ImageTextDataset(
323
- data_args.data_dir,
324
- data_args.train_file,
325
- captions_per_image=2,
326
- transform=preprocess,
327
- )
328
-
329
- eval_dataset = ImageTextDataset(
330
- data_args.data_dir,
331
- data_args.validation_file,
332
- captions_per_image=1,
333
- transform=preprocess,
334
- )
335
- # Import Translation Pipeline
336
- mt = Pororo(task="translation", lang="multi")
337
-
338
- # Store some constant
339
- num_epochs = int(training_args.num_train_epochs)
340
- train_batch_size = (
341
- int(training_args.per_device_train_batch_size) * jax.device_count()
342
- )
343
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
344
- steps_per_epoch = len(train_dataset) // train_batch_size
345
- total_train_steps = steps_per_epoch * num_epochs
346
-
347
- # Use collate function to tokenizer the text and convert the processed images to numpy
348
- def collate_fn(examples):
349
- pixel_values = (
350
- torch.stack([example[0] for example in examples])
351
- .permute(0, 2, 3, 1)
352
- .numpy()
353
- )
354
- en_captions = [example[1] for example in examples]
355
- captions = [mt(text, src="en", tgt="ko") for text in en_captions]
356
- inputs = tokenizer(
357
- captions,
358
- max_length=data_args.max_seq_length,
359
- padding="max_length",
360
- return_tensors="np",
361
- )
362
-
363
- batch = {
364
- "pixel_values": pixel_values,
365
- "input_ids": inputs["input_ids"],
366
- "attention_mask": inputs["attention_mask"],
367
- }
368
-
369
- return batch
370
-
371
- # Create data loaders
372
- train_loader = torch.utils.data.DataLoader(
373
- train_dataset,
374
- batch_size=train_batch_size,
375
- shuffle=True,
376
- num_workers=data_args.preprocessing_num_workers,
377
- persistent_workers=True,
378
- drop_last=True,
379
- collate_fn=collate_fn,
380
- )
381
-
382
- eval_loader = torch.utils.data.DataLoader(
383
- eval_dataset,
384
- batch_size=eval_batch_size,
385
- shuffle=False,
386
- num_workers=data_args.preprocessing_num_workers,
387
- persistent_workers=True,
388
- drop_last=True,
389
- collate_fn=collate_fn,
390
- )
391
-
392
- # Enable tensorboard only on the master node
393
- if has_tensorboard and jax.process_index() == 0:
394
- summary_writer = SummaryWriter(
395
- log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
396
- )
397
-
398
- # Initialize our training
399
- rng = jax.random.PRNGKey(training_args.seed)
400
- rng, dropout_rng = jax.random.split(rng)
401
-
402
- # Create learning rate schedule
403
- linear_decay_lr_schedule_fn = create_learning_rate_fn(
404
- len(train_dataset),
405
- train_batch_size,
406
- training_args.num_train_epochs,
407
- training_args.warmup_steps,
408
- training_args.learning_rate,
409
- )
410
-
411
- # create adam optimizer
412
- adamw = optax.adamw(
413
- learning_rate=linear_decay_lr_schedule_fn,
414
- b1=training_args.adam_beta1,
415
- b2=training_args.adam_beta2,
416
- eps=training_args.adam_epsilon,
417
- weight_decay=training_args.weight_decay,
418
- )
419
-
420
- # Setup train state
421
- state = TrainState.create(
422
- apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng
423
- )
424
-
425
- def cross_entropy(logits, axis):
426
- logprobs = jax.nn.log_softmax(logits, axis=axis)
427
- nll = jnp.diag(logprobs)
428
- ce = -jnp.mean(nll)
429
- return ce
430
-
431
- def clip_loss(similarity):
432
- loss = (
433
- cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
434
- ) / 2
435
- return loss
436
-
437
- # Define gradient update step fn
438
- def train_step(state, batch):
439
- dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
440
-
441
- def compute_loss(params):
442
- logits = state.apply_fn(
443
- **batch, params=params, dropout_rng=dropout_rng, train=True
444
- )[0]
445
- loss = clip_loss(logits)
446
- return loss
447
-
448
- grad_fn = jax.value_and_grad(compute_loss)
449
- loss, grad = grad_fn(state.params)
450
- grad = jax.lax.pmean(grad, "batch")
451
-
452
- new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
453
-
454
- metrics = {
455
- "loss": loss,
456
- "learning_rate": linear_decay_lr_schedule_fn(state.step),
457
- }
458
- metrics = jax.lax.pmean(metrics, axis_name="batch")
459
-
460
- return new_state, metrics
461
-
462
- # Define eval fn
463
- def eval_step(params, batch):
464
- logits = model(**batch, params=params, train=False)[0]
465
- loss = clip_loss(logits)
466
-
467
- # summarize metrics
468
- metrics = {"loss": loss}
469
- metrics = jax.lax.pmean(metrics, axis_name="batch")
470
- return metrics
471
-
472
- # Create parallel version of the train and eval step
473
- p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
474
- p_eval_step = jax.pmap(eval_step, "batch")
475
-
476
- # Replicate the train state on each device
477
- state = state.replicate()
478
-
479
- logger.info("***** Running training *****")
480
- logger.info(f" Num examples = {len(train_dataset)}")
481
- logger.info(f" Num Epochs = {num_epochs}")
482
- logger.info(
483
- f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
484
- )
485
- logger.info(
486
- f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
487
- )
488
- logger.info(f" Total optimization steps = {total_train_steps}")
489
-
490
- train_time = 0
491
- # Create sampling rng
492
- rng, input_rng = jax.random.split(rng)
493
-
494
- epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
495
- for epoch in epochs:
496
- # ======================== Training ================================
497
- train_start = time.time()
498
-
499
- # Create sampling rng
500
- rng, input_rng = jax.random.split(rng)
501
- train_metrics = []
502
-
503
- steps_per_epoch = len(train_dataset) // train_batch_size
504
- train_step_progress_bar = tqdm(
505
- total=steps_per_epoch, desc="Training...", position=1, leave=False
506
- )
507
- # train
508
- for batch in train_loader:
509
- batch = shard(batch)
510
- state, train_metric = p_train_step(state, batch)
511
- train_metrics.append(train_metric)
512
-
513
- train_step_progress_bar.update(1)
514
-
515
- train_time += time.time() - train_start
516
-
517
- train_metric = unreplicate(train_metric)
518
-
519
- train_step_progress_bar.close()
520
- epochs.write(
521
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
522
- )
523
-
524
- # ======================== Evaluating ==============================
525
- eval_metrics = []
526
- eval_steps = len(eval_dataset) // eval_batch_size
527
- eval_step_progress_bar = tqdm(
528
- total=eval_steps, desc="Evaluating...", position=2, leave=False
529
- )
530
- for batch in eval_loader:
531
- # Model forward
532
- batch = shard(batch)
533
- metrics = p_eval_step(state.params, batch)
534
- eval_metrics.append(metrics)
535
-
536
- eval_step_progress_bar.update(1)
537
-
538
- # normalize eval metrics
539
- eval_metrics = get_metrics(eval_metrics)
540
-
541
- eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
542
-
543
- # Print metrics and update progress bar
544
- eval_step_progress_bar.close()
545
- desc = (
546
- f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
547
- )
548
- epochs.write(desc)
549
- epochs.desc = desc
550
-
551
- # Save metrics
552
- if has_tensorboard and jax.process_index() == 0:
553
- cur_step = epoch * (len(train_dataset) // train_batch_size)
554
- write_metric(
555
- summary_writer, train_metrics, eval_metrics, train_time, cur_step
556
- )
557
-
558
- # save checkpoint after each epoch and push checkpoint to the hub
559
- if jax.process_index() == 0:
560
- params = jax.device_get(unreplicate(state.params))
561
- model.save_pretrained(
562
- training_args.output_dir,
563
- params=params,
564
- push_to_hub=training_args.push_to_hub,
565
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
566
- )
567
-
568
-
569
- if __name__ == "__main__":
570
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train.sh CHANGED
@@ -1,15 +1,14 @@
1
  python run_hybrid_clip.py \
2
- --output_dir . \
3
  --text_model_name_or_path="klue/roberta-large" \
4
  --vision_model_name_or_path="openai/clip-vit-base-patch32" \
5
  --tokenizer_name="klue/roberta-large" \
6
- --train_file="coco_dataset/train_dataset.json" \
7
- --validation_file="coco_dataset/validation_dataset.json" \
8
  --do_train --do_eval \
9
  --num_train_epochs="40" --max_seq_length 96 \
10
  --per_device_train_batch_size="64" \
11
  --per_device_eval_batch_size="64" \
12
  --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
13
  --overwrite_output_dir \
14
- --preprocessing_num_workers 32 \
15
- --push_to_hub
 
1
  python run_hybrid_clip.py \
2
+ --output_dir="models/coco_only" \
3
  --text_model_name_or_path="klue/roberta-large" \
4
  --vision_model_name_or_path="openai/clip-vit-base-patch32" \
5
  --tokenizer_name="klue/roberta-large" \
6
+ --train_file="../dataset/coco/train_annotations.json" \
7
+ --validation_file="../dataset/coco/valid_annotations.json" \
8
  --do_train --do_eval \
9
  --num_train_epochs="40" --max_seq_length 96 \
10
  --per_device_train_batch_size="64" \
11
  --per_device_eval_batch_size="64" \
12
  --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
13
  --overwrite_output_dir \
14
+ --preprocessing_num_workers 32