feature: add coco_only model ckpt
Browse files- .gitattributes +0 -0
- .gitignore +0 -0
- Makefile +0 -0
- README.md +2 -2
- configuration_hybrid_clip.py +0 -0
- dataloader.py +3 -3
- down_wit.py +0 -79
- modeling_hybrid_clip.py +0 -0
- models/coco_only/config.json +156 -0
- models/coco_only/flax_model.msgpack +3 -0
- requirements.txt +0 -0
- run_hybrid_clip.py +128 -122
- run_hybrid_clip_en.py +0 -570
- train.sh +4 -5
.gitattributes
CHANGED
|
File without changes
|
.gitignore
CHANGED
|
File without changes
|
Makefile
CHANGED
|
File without changes
|
README.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
#
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
## Installation
|
| 6 |
|
|
|
|
| 1 |
+
# KoCLIP
|
| 2 |
|
| 3 |
+
This repository includes
|
| 4 |
|
| 5 |
## Installation
|
| 6 |
|
configuration_hybrid_clip.py
CHANGED
|
File without changes
|
dataloader.py
CHANGED
|
@@ -53,7 +53,7 @@ class ImageTextDataset(VisionDataset):
|
|
| 53 |
self,
|
| 54 |
root: str,
|
| 55 |
file_path: str,
|
| 56 |
-
captions_per_image=
|
| 57 |
transform: Optional[Callable] = None,
|
| 58 |
target_transform: Optional[Callable] = None,
|
| 59 |
transforms: Optional[Callable] = None,
|
|
@@ -61,7 +61,7 @@ class ImageTextDataset(VisionDataset):
|
|
| 61 |
super().__init__(root, transforms, transform, target_transform)
|
| 62 |
|
| 63 |
with open(file_path, "r") as f:
|
| 64 |
-
examples =
|
| 65 |
|
| 66 |
self.captions = []
|
| 67 |
self.image_paths = []
|
|
@@ -69,7 +69,7 @@ class ImageTextDataset(VisionDataset):
|
|
| 69 |
for example in examples:
|
| 70 |
captions = example["captions"][:captions_per_image]
|
| 71 |
self.captions.extend(captions)
|
| 72 |
-
self.image_paths.extend([example["
|
| 73 |
|
| 74 |
def _load_image(self, idx: int):
|
| 75 |
path = self.image_paths[idx]
|
|
|
|
| 53 |
self,
|
| 54 |
root: str,
|
| 55 |
file_path: str,
|
| 56 |
+
captions_per_image=5,
|
| 57 |
transform: Optional[Callable] = None,
|
| 58 |
target_transform: Optional[Callable] = None,
|
| 59 |
transforms: Optional[Callable] = None,
|
|
|
|
| 61 |
super().__init__(root, transforms, transform, target_transform)
|
| 62 |
|
| 63 |
with open(file_path, "r") as f:
|
| 64 |
+
examples = json.load(f)
|
| 65 |
|
| 66 |
self.captions = []
|
| 67 |
self.image_paths = []
|
|
|
|
| 69 |
for example in examples:
|
| 70 |
captions = example["captions"][:captions_per_image]
|
| 71 |
self.captions.extend(captions)
|
| 72 |
+
self.image_paths.extend([example["file_path"]] * len(captions))
|
| 73 |
|
| 74 |
def _load_image(self, idx: int):
|
| 75 |
path = self.image_paths[idx]
|
down_wit.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import csv
|
| 3 |
-
import glob
|
| 4 |
-
from typing import Text, List
|
| 5 |
-
import urllib.request
|
| 6 |
-
import requests
|
| 7 |
-
from multiprocessing import Pool
|
| 8 |
-
import socket
|
| 9 |
-
timeout = 10
|
| 10 |
-
socket.setdefaulttimeout(timeout)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
DATA_PATH='/home/shared/dataset/wit'
|
| 14 |
-
# DATA_PATH='../data/wit'
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def load_file(path):
|
| 18 |
-
"""
|
| 19 |
-
load csv
|
| 20 |
-
"""
|
| 21 |
-
with open(path) as f:
|
| 22 |
-
reader = csv.reader(f, delimiter='\t', quotechar='"')
|
| 23 |
-
data = list(reader)
|
| 24 |
-
return data
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def extract_ko(data):
|
| 28 |
-
"""
|
| 29 |
-
Extract lang=ko data samples
|
| 30 |
-
"""
|
| 31 |
-
trainset = []
|
| 32 |
-
for samp in data[1:]:
|
| 33 |
-
if samp[0] != 'ko':
|
| 34 |
-
continue
|
| 35 |
-
trainset.append(samp)
|
| 36 |
-
return trainset
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def rewrite_wit(data_paths):
|
| 40 |
-
"""
|
| 41 |
-
we need only korean set. extract only korean set.
|
| 42 |
-
https://drive.google.com/file/d/1y_DxYrmUF4vw3m7UOlVsHSkcO_v0XuLv/view?usp=sharing
|
| 43 |
-
|
| 44 |
-
"""
|
| 45 |
-
samples = []
|
| 46 |
-
for path in data_paths:
|
| 47 |
-
data = load_file(path)
|
| 48 |
-
samples += extract_ko(data)
|
| 49 |
-
return [[i, *samp] for i, samp in enumerate(samples)]
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def req_imgs(url_info):
|
| 53 |
-
""" download imgs """
|
| 54 |
-
# request.get 요청
|
| 55 |
-
response = requests.get(url_info[1], headers={'User-agent': 'your bot 0.1'})
|
| 56 |
-
with open(f'{DATA_PATH}/img/{url_info[0]}.jpg', 'wb') as f:
|
| 57 |
-
f.write(response.content)
|
| 58 |
-
# print(f"{url_info[0]} is done.")
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def down_imgs(urls):
|
| 62 |
-
with Pool(2) as p:
|
| 63 |
-
p.map(req_imgs, urls)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
if __name__ == '__main__':
|
| 67 |
-
# path_list = glob.glob('/home/shared/dataset/wit')
|
| 68 |
-
path_list = glob.glob(f'{DATA_PATH}/info/*')
|
| 69 |
-
|
| 70 |
-
samples = rewrite_wit(path_list)
|
| 71 |
-
|
| 72 |
-
with open(f'{DATA_PATH}/wit_ko.csv', 'w') as f:
|
| 73 |
-
writer = csv.writer(f, delimiter='\t', quotechar='"')
|
| 74 |
-
writer.writerows(samples)
|
| 75 |
-
|
| 76 |
-
url_list = [[samp[0], samp[3]] for samp in samples]
|
| 77 |
-
down_imgs(url_list)
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modeling_hybrid_clip.py
CHANGED
|
File without changes
|
models/coco_only/config.json
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"HybridCLIP"
|
| 4 |
+
],
|
| 5 |
+
"initializer_factor": 1.0,
|
| 6 |
+
"model_type": "hybrid-clip",
|
| 7 |
+
"projection_dim": 512,
|
| 8 |
+
"seed": 42,
|
| 9 |
+
"text_config": {
|
| 10 |
+
"_name_or_path": "",
|
| 11 |
+
"add_cross_attention": false,
|
| 12 |
+
"architectures": [
|
| 13 |
+
"RobertaForMaskedLM"
|
| 14 |
+
],
|
| 15 |
+
"attention_probs_dropout_prob": 0.1,
|
| 16 |
+
"bad_words_ids": null,
|
| 17 |
+
"bos_token_id": 0,
|
| 18 |
+
"chunk_size_feed_forward": 0,
|
| 19 |
+
"decoder_start_token_id": null,
|
| 20 |
+
"diversity_penalty": 0.0,
|
| 21 |
+
"do_sample": false,
|
| 22 |
+
"early_stopping": false,
|
| 23 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 24 |
+
"eos_token_id": 2,
|
| 25 |
+
"finetuning_task": null,
|
| 26 |
+
"forced_bos_token_id": null,
|
| 27 |
+
"forced_eos_token_id": null,
|
| 28 |
+
"gradient_checkpointing": false,
|
| 29 |
+
"hidden_act": "gelu",
|
| 30 |
+
"hidden_dropout_prob": 0.1,
|
| 31 |
+
"hidden_size": 1024,
|
| 32 |
+
"id2label": {
|
| 33 |
+
"0": "LABEL_0",
|
| 34 |
+
"1": "LABEL_1"
|
| 35 |
+
},
|
| 36 |
+
"initializer_range": 0.02,
|
| 37 |
+
"intermediate_size": 4096,
|
| 38 |
+
"is_decoder": false,
|
| 39 |
+
"is_encoder_decoder": false,
|
| 40 |
+
"label2id": {
|
| 41 |
+
"LABEL_0": 0,
|
| 42 |
+
"LABEL_1": 1
|
| 43 |
+
},
|
| 44 |
+
"layer_norm_eps": 1e-05,
|
| 45 |
+
"length_penalty": 1.0,
|
| 46 |
+
"max_length": 20,
|
| 47 |
+
"max_position_embeddings": 512,
|
| 48 |
+
"min_length": 0,
|
| 49 |
+
"model_type": "roberta",
|
| 50 |
+
"no_repeat_ngram_size": 0,
|
| 51 |
+
"num_attention_heads": 16,
|
| 52 |
+
"num_beam_groups": 1,
|
| 53 |
+
"num_beams": 1,
|
| 54 |
+
"num_hidden_layers": 24,
|
| 55 |
+
"num_return_sequences": 1,
|
| 56 |
+
"output_attentions": false,
|
| 57 |
+
"output_hidden_states": false,
|
| 58 |
+
"output_scores": false,
|
| 59 |
+
"pad_token_id": 1,
|
| 60 |
+
"position_embedding_type": "absolute",
|
| 61 |
+
"prefix": null,
|
| 62 |
+
"problem_type": null,
|
| 63 |
+
"pruned_heads": {},
|
| 64 |
+
"remove_invalid_values": false,
|
| 65 |
+
"repetition_penalty": 1.0,
|
| 66 |
+
"return_dict": true,
|
| 67 |
+
"return_dict_in_generate": false,
|
| 68 |
+
"sep_token_id": null,
|
| 69 |
+
"task_specific_params": null,
|
| 70 |
+
"temperature": 1.0,
|
| 71 |
+
"tie_encoder_decoder": false,
|
| 72 |
+
"tie_word_embeddings": true,
|
| 73 |
+
"tokenizer_class": "BertTokenizer",
|
| 74 |
+
"top_k": 50,
|
| 75 |
+
"top_p": 1.0,
|
| 76 |
+
"torch_dtype": null,
|
| 77 |
+
"torchscript": false,
|
| 78 |
+
"transformers_version": "4.9.0.dev0",
|
| 79 |
+
"type_vocab_size": 1,
|
| 80 |
+
"use_bfloat16": false,
|
| 81 |
+
"use_cache": true,
|
| 82 |
+
"vocab_size": 32000
|
| 83 |
+
},
|
| 84 |
+
"transformers_version": null,
|
| 85 |
+
"vision_config": {
|
| 86 |
+
"_name_or_path": "",
|
| 87 |
+
"add_cross_attention": false,
|
| 88 |
+
"architectures": null,
|
| 89 |
+
"attention_dropout": 0.0,
|
| 90 |
+
"bad_words_ids": null,
|
| 91 |
+
"bos_token_id": null,
|
| 92 |
+
"chunk_size_feed_forward": 0,
|
| 93 |
+
"decoder_start_token_id": null,
|
| 94 |
+
"diversity_penalty": 0.0,
|
| 95 |
+
"do_sample": false,
|
| 96 |
+
"dropout": 0.0,
|
| 97 |
+
"early_stopping": false,
|
| 98 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 99 |
+
"eos_token_id": null,
|
| 100 |
+
"finetuning_task": null,
|
| 101 |
+
"forced_bos_token_id": null,
|
| 102 |
+
"forced_eos_token_id": null,
|
| 103 |
+
"gradient_checkpointing": false,
|
| 104 |
+
"hidden_act": "quick_gelu",
|
| 105 |
+
"hidden_size": 768,
|
| 106 |
+
"id2label": {
|
| 107 |
+
"0": "LABEL_0",
|
| 108 |
+
"1": "LABEL_1"
|
| 109 |
+
},
|
| 110 |
+
"image_size": 224,
|
| 111 |
+
"initializer_factor": 1.0,
|
| 112 |
+
"initializer_range": 0.02,
|
| 113 |
+
"intermediate_size": 3072,
|
| 114 |
+
"is_decoder": false,
|
| 115 |
+
"is_encoder_decoder": false,
|
| 116 |
+
"label2id": {
|
| 117 |
+
"LABEL_0": 0,
|
| 118 |
+
"LABEL_1": 1
|
| 119 |
+
},
|
| 120 |
+
"layer_norm_eps": 1e-05,
|
| 121 |
+
"length_penalty": 1.0,
|
| 122 |
+
"max_length": 20,
|
| 123 |
+
"min_length": 0,
|
| 124 |
+
"model_type": "clip_vision_model",
|
| 125 |
+
"no_repeat_ngram_size": 0,
|
| 126 |
+
"num_attention_heads": 12,
|
| 127 |
+
"num_beam_groups": 1,
|
| 128 |
+
"num_beams": 1,
|
| 129 |
+
"num_hidden_layers": 12,
|
| 130 |
+
"num_return_sequences": 1,
|
| 131 |
+
"output_attentions": false,
|
| 132 |
+
"output_hidden_states": false,
|
| 133 |
+
"output_scores": false,
|
| 134 |
+
"pad_token_id": null,
|
| 135 |
+
"patch_size": 32,
|
| 136 |
+
"prefix": null,
|
| 137 |
+
"problem_type": null,
|
| 138 |
+
"pruned_heads": {},
|
| 139 |
+
"remove_invalid_values": false,
|
| 140 |
+
"repetition_penalty": 1.0,
|
| 141 |
+
"return_dict": true,
|
| 142 |
+
"return_dict_in_generate": false,
|
| 143 |
+
"sep_token_id": null,
|
| 144 |
+
"task_specific_params": null,
|
| 145 |
+
"temperature": 1.0,
|
| 146 |
+
"tie_encoder_decoder": false,
|
| 147 |
+
"tie_word_embeddings": true,
|
| 148 |
+
"tokenizer_class": null,
|
| 149 |
+
"top_k": 50,
|
| 150 |
+
"top_p": 1.0,
|
| 151 |
+
"torch_dtype": null,
|
| 152 |
+
"torchscript": false,
|
| 153 |
+
"transformers_version": "4.9.0.dev0",
|
| 154 |
+
"use_bfloat16": false
|
| 155 |
+
}
|
| 156 |
+
}
|
models/coco_only/flax_model.msgpack
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d1397edcc4c8f8e3c72fcb4a3cfdc742aa6ff727206f601e100e4df7398b2001
|
| 3 |
+
size 1700132358
|
requirements.txt
CHANGED
|
File without changes
|
run_hybrid_clip.py
CHANGED
|
@@ -31,24 +31,30 @@ from dataclasses import dataclass, field
|
|
| 31 |
from pathlib import Path
|
| 32 |
from typing import Callable, Optional
|
| 33 |
|
| 34 |
-
import torch
|
| 35 |
-
from torchvision.datasets import VisionDataset
|
| 36 |
-
from torchvision.io import ImageReadMode, read_image
|
| 37 |
-
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
|
| 38 |
-
from torchvision.transforms.functional import InterpolationMode
|
| 39 |
-
from tqdm import tqdm
|
| 40 |
-
|
| 41 |
import jax
|
| 42 |
import jax.numpy as jnp
|
| 43 |
import optax
|
|
|
|
| 44 |
import transformers
|
| 45 |
from flax import jax_utils
|
| 46 |
from flax.jax_utils import unreplicate
|
| 47 |
from flax.training import train_state
|
| 48 |
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
from modeling_hybrid_clip import FlaxHybridCLIP
|
| 50 |
-
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
|
| 51 |
-
|
| 52 |
|
| 53 |
logger = logging.getLogger(__name__)
|
| 54 |
|
|
@@ -59,7 +65,9 @@ if has_tensorboard:
|
|
| 59 |
from flax.metrics.tensorboard import SummaryWriter
|
| 60 |
except ImportError as ie:
|
| 61 |
has_tensorboard = False
|
| 62 |
-
print(
|
|
|
|
|
|
|
| 63 |
|
| 64 |
else:
|
| 65 |
print(
|
|
@@ -88,20 +96,33 @@ class ModelArguments:
|
|
| 88 |
)
|
| 89 |
from_pt: bool = field(
|
| 90 |
default=True,
|
| 91 |
-
metadata={
|
|
|
|
|
|
|
| 92 |
)
|
| 93 |
config_name: Optional[str] = field(
|
| 94 |
-
default=None,
|
|
|
|
|
|
|
|
|
|
| 95 |
)
|
| 96 |
tokenizer_name: Optional[str] = field(
|
| 97 |
-
default=None,
|
|
|
|
|
|
|
|
|
|
| 98 |
)
|
| 99 |
cache_dir: Optional[str] = field(
|
| 100 |
-
default=None,
|
|
|
|
|
|
|
|
|
|
| 101 |
)
|
| 102 |
use_fast_tokenizer: bool = field(
|
| 103 |
default=True,
|
| 104 |
-
metadata={
|
|
|
|
|
|
|
| 105 |
)
|
| 106 |
dtype: Optional[str] = field(
|
| 107 |
default="float32",
|
|
@@ -117,9 +138,12 @@ class DataTrainingArguments:
|
|
| 117 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 118 |
"""
|
| 119 |
|
| 120 |
-
data_dir: Optional[str] = field(
|
|
|
|
|
|
|
| 121 |
train_file: Optional[str] = field(
|
| 122 |
-
default=None,
|
|
|
|
| 123 |
)
|
| 124 |
validation_file: Optional[str] = field(
|
| 125 |
default=None,
|
|
@@ -147,10 +171,12 @@ class DataTrainingArguments:
|
|
| 147 |
},
|
| 148 |
)
|
| 149 |
overwrite_cache: bool = field(
|
| 150 |
-
default=False,
|
|
|
|
| 151 |
)
|
| 152 |
overwrite_cache: bool = field(
|
| 153 |
-
default=False,
|
|
|
|
| 154 |
)
|
| 155 |
preprocessing_num_workers: Optional[int] = field(
|
| 156 |
default=None,
|
|
@@ -159,7 +185,9 @@ class DataTrainingArguments:
|
|
| 159 |
|
| 160 |
def __post_init__(self):
|
| 161 |
if self.train_file is None and self.validation_file is None:
|
| 162 |
-
raise ValueError(
|
|
|
|
|
|
|
| 163 |
else:
|
| 164 |
if self.train_file is not None:
|
| 165 |
extension = self.train_file.split(".")[-1]
|
|
@@ -169,87 +197,13 @@ class DataTrainingArguments:
|
|
| 169 |
assert extension == "json", "`validation_file` should be a json file."
|
| 170 |
|
| 171 |
|
| 172 |
-
# We use torchvision for faster image pre-processing.
|
| 173 |
-
# We need to ensure faster processing speed as it can become a bottleneck on TPU
|
| 174 |
-
class Transform(torch.nn.Module):
|
| 175 |
-
def __init__(self, image_size):
|
| 176 |
-
super().__init__()
|
| 177 |
-
self.transforms = torch.nn.Sequential(
|
| 178 |
-
Resize([image_size], interpolation=InterpolationMode.BICUBIC),
|
| 179 |
-
CenterCrop(image_size),
|
| 180 |
-
ConvertImageDtype(torch.float),
|
| 181 |
-
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 185 |
-
with torch.no_grad():
|
| 186 |
-
x = self.transforms(x)
|
| 187 |
-
return x
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
class ImageTextDataset(VisionDataset):
|
| 191 |
-
"""
|
| 192 |
-
Dtaset for loading image-text data for tasks like CLIP training, Image Captioning.
|
| 193 |
-
Args:
|
| 194 |
-
root: (string): The root path where the dataset is stored
|
| 195 |
-
file_path: (string): Path to the file containing the image_paths and associated captions.
|
| 196 |
-
The expected format is jsonlines where each line is a json object containing to keys.
|
| 197 |
-
`image_path`: The path to the image.
|
| 198 |
-
`captions`: An `array` of captions.
|
| 199 |
-
transform (callable, optional): A function/transform that takes in an PIL image
|
| 200 |
-
and returns a transformed version. E.g, ``transforms.ToTensor``
|
| 201 |
-
target_transform (callable, optional): A function/transform that takes in the
|
| 202 |
-
target and transforms it.
|
| 203 |
-
transforms (callable, optional): A function/transform that takes input sample and its target as entry
|
| 204 |
-
and returns a transformed version.
|
| 205 |
-
"""
|
| 206 |
-
|
| 207 |
-
def __init__(
|
| 208 |
-
self,
|
| 209 |
-
root: str,
|
| 210 |
-
file_path: str,
|
| 211 |
-
captions_per_image=2,
|
| 212 |
-
transform: Optional[Callable] = None,
|
| 213 |
-
target_transform: Optional[Callable] = None,
|
| 214 |
-
transforms: Optional[Callable] = None,
|
| 215 |
-
):
|
| 216 |
-
super().__init__(root, transforms, transform, target_transform)
|
| 217 |
-
|
| 218 |
-
with open(file_path, "r") as f:
|
| 219 |
-
examples = [json.loads(line) for line in f.readlines()]
|
| 220 |
-
|
| 221 |
-
self.captions = []
|
| 222 |
-
self.image_paths = []
|
| 223 |
-
|
| 224 |
-
for example in examples:
|
| 225 |
-
self.captions.extend(example["captions"][:captions_per_image])
|
| 226 |
-
self.image_paths.extend([example["image_path"]] * captions_per_image)
|
| 227 |
-
|
| 228 |
-
def _load_image(self, idx: int):
|
| 229 |
-
path = self.image_paths[idx]
|
| 230 |
-
return read_image(path, mode=ImageReadMode.RGB)
|
| 231 |
-
|
| 232 |
-
def _load_target(self, idx):
|
| 233 |
-
return self.captions[idx]
|
| 234 |
-
|
| 235 |
-
def __getitem__(self, index: int):
|
| 236 |
-
image = self._load_image(index)
|
| 237 |
-
target = self._load_target(index)
|
| 238 |
-
|
| 239 |
-
if self.transforms is not None:
|
| 240 |
-
image, target = self.transforms(image, target)
|
| 241 |
-
|
| 242 |
-
return image, target
|
| 243 |
-
|
| 244 |
-
def __len__(self) -> int:
|
| 245 |
-
return len(self.captions)
|
| 246 |
-
|
| 247 |
-
|
| 248 |
class TrainState(train_state.TrainState):
|
| 249 |
dropout_rng: jnp.ndarray
|
| 250 |
|
| 251 |
def replicate(self):
|
| 252 |
-
return jax_utils.replicate(self).replace(
|
|
|
|
|
|
|
| 253 |
|
| 254 |
|
| 255 |
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|
@@ -266,25 +220,39 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|
| 266 |
|
| 267 |
|
| 268 |
def create_learning_rate_fn(
|
| 269 |
-
train_ds_size: int,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
) -> Callable[[int], jnp.array]:
|
| 271 |
"""Returns a linear warmup, linear_decay learning rate function."""
|
| 272 |
steps_per_epoch = train_ds_size // train_batch_size
|
| 273 |
num_train_steps = steps_per_epoch * num_train_epochs
|
| 274 |
-
warmup_fn = optax.linear_schedule(
|
|
|
|
|
|
|
| 275 |
decay_fn = optax.linear_schedule(
|
| 276 |
-
init_value=learning_rate,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
)
|
| 278 |
-
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
| 279 |
return schedule_fn
|
| 280 |
|
| 281 |
|
| 282 |
def main():
|
| 283 |
-
parser = HfArgumentParser(
|
|
|
|
|
|
|
| 284 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 285 |
# If we pass only one argument to the script and it's the path to a json file,
|
| 286 |
# let's parse it to get our arguments.
|
| 287 |
-
model_args, data_args, training_args = parser.parse_json_file(
|
|
|
|
|
|
|
| 288 |
else:
|
| 289 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 290 |
|
|
@@ -317,11 +285,15 @@ def main():
|
|
| 317 |
|
| 318 |
if model_args.tokenizer_name:
|
| 319 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 320 |
-
model_args.tokenizer_name,
|
|
|
|
|
|
|
| 321 |
)
|
| 322 |
elif model_args.text_model_name_or_path:
|
| 323 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 324 |
-
model_args.text_model_name_or_path,
|
|
|
|
|
|
|
| 325 |
)
|
| 326 |
else:
|
| 327 |
raise ValueError(
|
|
@@ -349,29 +321,40 @@ def main():
|
|
| 349 |
train_dataset = ImageTextDataset(
|
| 350 |
data_args.data_dir,
|
| 351 |
data_args.train_file,
|
| 352 |
-
captions_per_image=
|
| 353 |
transform=preprocess,
|
| 354 |
)
|
| 355 |
|
| 356 |
eval_dataset = ImageTextDataset(
|
| 357 |
data_args.data_dir,
|
| 358 |
data_args.validation_file,
|
| 359 |
-
captions_per_image=
|
| 360 |
transform=preprocess,
|
| 361 |
)
|
| 362 |
|
| 363 |
# Store some constant
|
| 364 |
num_epochs = int(training_args.num_train_epochs)
|
| 365 |
-
train_batch_size =
|
|
|
|
|
|
|
| 366 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
| 367 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 368 |
total_train_steps = steps_per_epoch * num_epochs
|
| 369 |
|
| 370 |
# Use collate function to tokenizer the text and convert the processed images to numpy
|
| 371 |
def collate_fn(examples):
|
| 372 |
-
pixel_values =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
captions = [example[1] for example in examples]
|
| 374 |
-
inputs = tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
|
| 376 |
batch = {
|
| 377 |
"pixel_values": pixel_values,
|
|
@@ -404,7 +387,9 @@ def main():
|
|
| 404 |
|
| 405 |
# Enable tensorboard only on the master node
|
| 406 |
if has_tensorboard and jax.process_index() == 0:
|
| 407 |
-
summary_writer = SummaryWriter(
|
|
|
|
|
|
|
| 408 |
|
| 409 |
# Initialize our training
|
| 410 |
rng = jax.random.PRNGKey(training_args.seed)
|
|
@@ -429,7 +414,9 @@ def main():
|
|
| 429 |
)
|
| 430 |
|
| 431 |
# Setup train state
|
| 432 |
-
state = TrainState.create(
|
|
|
|
|
|
|
| 433 |
|
| 434 |
def cross_entropy(logits, axis):
|
| 435 |
logprobs = jax.nn.log_softmax(logits, axis=axis)
|
|
@@ -438,7 +425,9 @@ def main():
|
|
| 438 |
return ce
|
| 439 |
|
| 440 |
def clip_loss(similarity):
|
| 441 |
-
loss = (
|
|
|
|
|
|
|
| 442 |
return loss
|
| 443 |
|
| 444 |
# Define gradient update step fn
|
|
@@ -446,7 +435,9 @@ def main():
|
|
| 446 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
| 447 |
|
| 448 |
def compute_loss(params):
|
| 449 |
-
logits = state.apply_fn(
|
|
|
|
|
|
|
| 450 |
loss = clip_loss(logits)
|
| 451 |
return loss
|
| 452 |
|
|
@@ -456,7 +447,10 @@ def main():
|
|
| 456 |
|
| 457 |
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
| 458 |
|
| 459 |
-
metrics = {
|
|
|
|
|
|
|
|
|
|
| 460 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 461 |
|
| 462 |
return new_state, metrics
|
|
@@ -481,8 +475,12 @@ def main():
|
|
| 481 |
logger.info("***** Running training *****")
|
| 482 |
logger.info(f" Num examples = {len(train_dataset)}")
|
| 483 |
logger.info(f" Num Epochs = {num_epochs}")
|
| 484 |
-
logger.info(
|
| 485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
logger.info(f" Total optimization steps = {total_train_steps}")
|
| 487 |
|
| 488 |
train_time = 0
|
|
@@ -499,7 +497,9 @@ def main():
|
|
| 499 |
train_metrics = []
|
| 500 |
|
| 501 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 502 |
-
train_step_progress_bar = tqdm(
|
|
|
|
|
|
|
| 503 |
# train
|
| 504 |
for batch in train_loader:
|
| 505 |
batch = shard(batch)
|
|
@@ -520,7 +520,9 @@ def main():
|
|
| 520 |
# ======================== Evaluating ==============================
|
| 521 |
eval_metrics = []
|
| 522 |
eval_steps = len(eval_dataset) // eval_batch_size
|
| 523 |
-
eval_step_progress_bar = tqdm(
|
|
|
|
|
|
|
| 524 |
for batch in eval_loader:
|
| 525 |
# Model forward
|
| 526 |
batch = shard(batch)
|
|
@@ -536,14 +538,18 @@ def main():
|
|
| 536 |
|
| 537 |
# Print metrics and update progress bar
|
| 538 |
eval_step_progress_bar.close()
|
| 539 |
-
desc =
|
|
|
|
|
|
|
| 540 |
epochs.write(desc)
|
| 541 |
epochs.desc = desc
|
| 542 |
|
| 543 |
# Save metrics
|
| 544 |
if has_tensorboard and jax.process_index() == 0:
|
| 545 |
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
| 546 |
-
write_metric(
|
|
|
|
|
|
|
| 547 |
|
| 548 |
# save checkpoint after each epoch and push checkpoint to the hub
|
| 549 |
if jax.process_index() == 0:
|
|
@@ -557,4 +563,4 @@ def main():
|
|
| 557 |
|
| 558 |
|
| 559 |
if __name__ == "__main__":
|
| 560 |
-
main()
|
|
|
|
| 31 |
from pathlib import Path
|
| 32 |
from typing import Callable, Optional
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
import jax
|
| 35 |
import jax.numpy as jnp
|
| 36 |
import optax
|
| 37 |
+
import torch
|
| 38 |
import transformers
|
| 39 |
from flax import jax_utils
|
| 40 |
from flax.jax_utils import unreplicate
|
| 41 |
from flax.training import train_state
|
| 42 |
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
| 43 |
+
from torchvision.datasets import VisionDataset
|
| 44 |
+
from torchvision.io import ImageReadMode, read_image
|
| 45 |
+
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
|
| 46 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 47 |
+
from tqdm import tqdm
|
| 48 |
+
from transformers import (
|
| 49 |
+
AutoTokenizer,
|
| 50 |
+
HfArgumentParser,
|
| 51 |
+
TrainingArguments,
|
| 52 |
+
is_tensorboard_available,
|
| 53 |
+
set_seed,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
from dataloader import ImageTextDataset, Transform
|
| 57 |
from modeling_hybrid_clip import FlaxHybridCLIP
|
|
|
|
|
|
|
| 58 |
|
| 59 |
logger = logging.getLogger(__name__)
|
| 60 |
|
|
|
|
| 65 |
from flax.metrics.tensorboard import SummaryWriter
|
| 66 |
except ImportError as ie:
|
| 67 |
has_tensorboard = False
|
| 68 |
+
print(
|
| 69 |
+
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
| 70 |
+
)
|
| 71 |
|
| 72 |
else:
|
| 73 |
print(
|
|
|
|
| 96 |
)
|
| 97 |
from_pt: bool = field(
|
| 98 |
default=True,
|
| 99 |
+
metadata={
|
| 100 |
+
"help": "whether to load the text and vision model using PyTorch checkpoints."
|
| 101 |
+
},
|
| 102 |
)
|
| 103 |
config_name: Optional[str] = field(
|
| 104 |
+
default=None,
|
| 105 |
+
metadata={
|
| 106 |
+
"help": "Pretrained config name or path if not the same as model_name"
|
| 107 |
+
},
|
| 108 |
)
|
| 109 |
tokenizer_name: Optional[str] = field(
|
| 110 |
+
default=None,
|
| 111 |
+
metadata={
|
| 112 |
+
"help": "Pretrained tokenizer name or path if not the same as model_name"
|
| 113 |
+
},
|
| 114 |
)
|
| 115 |
cache_dir: Optional[str] = field(
|
| 116 |
+
default=None,
|
| 117 |
+
metadata={
|
| 118 |
+
"help": "Where do you want to store the pretrained models downloaded from s3"
|
| 119 |
+
},
|
| 120 |
)
|
| 121 |
use_fast_tokenizer: bool = field(
|
| 122 |
default=True,
|
| 123 |
+
metadata={
|
| 124 |
+
"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
|
| 125 |
+
},
|
| 126 |
)
|
| 127 |
dtype: Optional[str] = field(
|
| 128 |
default="float32",
|
|
|
|
| 138 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 139 |
"""
|
| 140 |
|
| 141 |
+
data_dir: Optional[str] = field(
|
| 142 |
+
default=None, metadata={"help": "The data directory containing input files."}
|
| 143 |
+
)
|
| 144 |
train_file: Optional[str] = field(
|
| 145 |
+
default=None,
|
| 146 |
+
metadata={"help": "The input training data file (a jsonlines file)."},
|
| 147 |
)
|
| 148 |
validation_file: Optional[str] = field(
|
| 149 |
default=None,
|
|
|
|
| 171 |
},
|
| 172 |
)
|
| 173 |
overwrite_cache: bool = field(
|
| 174 |
+
default=False,
|
| 175 |
+
metadata={"help": "Overwrite the cached training and evaluation sets"},
|
| 176 |
)
|
| 177 |
overwrite_cache: bool = field(
|
| 178 |
+
default=False,
|
| 179 |
+
metadata={"help": "Overwrite the cached training and evaluation sets"},
|
| 180 |
)
|
| 181 |
preprocessing_num_workers: Optional[int] = field(
|
| 182 |
default=None,
|
|
|
|
| 185 |
|
| 186 |
def __post_init__(self):
|
| 187 |
if self.train_file is None and self.validation_file is None:
|
| 188 |
+
raise ValueError(
|
| 189 |
+
"Need either a dataset name or a training/validation file."
|
| 190 |
+
)
|
| 191 |
else:
|
| 192 |
if self.train_file is not None:
|
| 193 |
extension = self.train_file.split(".")[-1]
|
|
|
|
| 197 |
assert extension == "json", "`validation_file` should be a json file."
|
| 198 |
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
class TrainState(train_state.TrainState):
|
| 201 |
dropout_rng: jnp.ndarray
|
| 202 |
|
| 203 |
def replicate(self):
|
| 204 |
+
return jax_utils.replicate(self).replace(
|
| 205 |
+
dropout_rng=shard_prng_key(self.dropout_rng)
|
| 206 |
+
)
|
| 207 |
|
| 208 |
|
| 209 |
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|
|
|
| 220 |
|
| 221 |
|
| 222 |
def create_learning_rate_fn(
|
| 223 |
+
train_ds_size: int,
|
| 224 |
+
train_batch_size: int,
|
| 225 |
+
num_train_epochs: int,
|
| 226 |
+
num_warmup_steps: int,
|
| 227 |
+
learning_rate: float,
|
| 228 |
) -> Callable[[int], jnp.array]:
|
| 229 |
"""Returns a linear warmup, linear_decay learning rate function."""
|
| 230 |
steps_per_epoch = train_ds_size // train_batch_size
|
| 231 |
num_train_steps = steps_per_epoch * num_train_epochs
|
| 232 |
+
warmup_fn = optax.linear_schedule(
|
| 233 |
+
init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
|
| 234 |
+
)
|
| 235 |
decay_fn = optax.linear_schedule(
|
| 236 |
+
init_value=learning_rate,
|
| 237 |
+
end_value=0,
|
| 238 |
+
transition_steps=num_train_steps - num_warmup_steps,
|
| 239 |
+
)
|
| 240 |
+
schedule_fn = optax.join_schedules(
|
| 241 |
+
schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
|
| 242 |
)
|
|
|
|
| 243 |
return schedule_fn
|
| 244 |
|
| 245 |
|
| 246 |
def main():
|
| 247 |
+
parser = HfArgumentParser(
|
| 248 |
+
(ModelArguments, DataTrainingArguments, TrainingArguments)
|
| 249 |
+
)
|
| 250 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 251 |
# If we pass only one argument to the script and it's the path to a json file,
|
| 252 |
# let's parse it to get our arguments.
|
| 253 |
+
model_args, data_args, training_args = parser.parse_json_file(
|
| 254 |
+
json_file=os.path.abspath(sys.argv[1])
|
| 255 |
+
)
|
| 256 |
else:
|
| 257 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 258 |
|
|
|
|
| 285 |
|
| 286 |
if model_args.tokenizer_name:
|
| 287 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 288 |
+
model_args.tokenizer_name,
|
| 289 |
+
cache_dir=model_args.cache_dir,
|
| 290 |
+
use_fast=model_args.use_fast_tokenizer,
|
| 291 |
)
|
| 292 |
elif model_args.text_model_name_or_path:
|
| 293 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 294 |
+
model_args.text_model_name_or_path,
|
| 295 |
+
cache_dir=model_args.cache_dir,
|
| 296 |
+
use_fast=model_args.use_fast_tokenizer,
|
| 297 |
)
|
| 298 |
else:
|
| 299 |
raise ValueError(
|
|
|
|
| 321 |
train_dataset = ImageTextDataset(
|
| 322 |
data_args.data_dir,
|
| 323 |
data_args.train_file,
|
| 324 |
+
captions_per_image=5,
|
| 325 |
transform=preprocess,
|
| 326 |
)
|
| 327 |
|
| 328 |
eval_dataset = ImageTextDataset(
|
| 329 |
data_args.data_dir,
|
| 330 |
data_args.validation_file,
|
| 331 |
+
captions_per_image=5,
|
| 332 |
transform=preprocess,
|
| 333 |
)
|
| 334 |
|
| 335 |
# Store some constant
|
| 336 |
num_epochs = int(training_args.num_train_epochs)
|
| 337 |
+
train_batch_size = (
|
| 338 |
+
int(training_args.per_device_train_batch_size) * jax.device_count()
|
| 339 |
+
)
|
| 340 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
| 341 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 342 |
total_train_steps = steps_per_epoch * num_epochs
|
| 343 |
|
| 344 |
# Use collate function to tokenizer the text and convert the processed images to numpy
|
| 345 |
def collate_fn(examples):
|
| 346 |
+
pixel_values = (
|
| 347 |
+
torch.stack([example[0] for example in examples])
|
| 348 |
+
.permute(0, 2, 3, 1)
|
| 349 |
+
.numpy()
|
| 350 |
+
)
|
| 351 |
captions = [example[1] for example in examples]
|
| 352 |
+
inputs = tokenizer(
|
| 353 |
+
captions,
|
| 354 |
+
max_length=data_args.max_seq_length,
|
| 355 |
+
padding="max_length",
|
| 356 |
+
return_tensors="np",
|
| 357 |
+
)
|
| 358 |
|
| 359 |
batch = {
|
| 360 |
"pixel_values": pixel_values,
|
|
|
|
| 387 |
|
| 388 |
# Enable tensorboard only on the master node
|
| 389 |
if has_tensorboard and jax.process_index() == 0:
|
| 390 |
+
summary_writer = SummaryWriter(
|
| 391 |
+
log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
|
| 392 |
+
)
|
| 393 |
|
| 394 |
# Initialize our training
|
| 395 |
rng = jax.random.PRNGKey(training_args.seed)
|
|
|
|
| 414 |
)
|
| 415 |
|
| 416 |
# Setup train state
|
| 417 |
+
state = TrainState.create(
|
| 418 |
+
apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng
|
| 419 |
+
)
|
| 420 |
|
| 421 |
def cross_entropy(logits, axis):
|
| 422 |
logprobs = jax.nn.log_softmax(logits, axis=axis)
|
|
|
|
| 425 |
return ce
|
| 426 |
|
| 427 |
def clip_loss(similarity):
|
| 428 |
+
loss = (
|
| 429 |
+
cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
|
| 430 |
+
) / 2
|
| 431 |
return loss
|
| 432 |
|
| 433 |
# Define gradient update step fn
|
|
|
|
| 435 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
| 436 |
|
| 437 |
def compute_loss(params):
|
| 438 |
+
logits = state.apply_fn(
|
| 439 |
+
**batch, params=params, dropout_rng=dropout_rng, train=True
|
| 440 |
+
)[0]
|
| 441 |
loss = clip_loss(logits)
|
| 442 |
return loss
|
| 443 |
|
|
|
|
| 447 |
|
| 448 |
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
| 449 |
|
| 450 |
+
metrics = {
|
| 451 |
+
"loss": loss,
|
| 452 |
+
"learning_rate": linear_decay_lr_schedule_fn(state.step),
|
| 453 |
+
}
|
| 454 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 455 |
|
| 456 |
return new_state, metrics
|
|
|
|
| 475 |
logger.info("***** Running training *****")
|
| 476 |
logger.info(f" Num examples = {len(train_dataset)}")
|
| 477 |
logger.info(f" Num Epochs = {num_epochs}")
|
| 478 |
+
logger.info(
|
| 479 |
+
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
|
| 480 |
+
)
|
| 481 |
+
logger.info(
|
| 482 |
+
f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
|
| 483 |
+
)
|
| 484 |
logger.info(f" Total optimization steps = {total_train_steps}")
|
| 485 |
|
| 486 |
train_time = 0
|
|
|
|
| 497 |
train_metrics = []
|
| 498 |
|
| 499 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 500 |
+
train_step_progress_bar = tqdm(
|
| 501 |
+
total=steps_per_epoch, desc="Training...", position=1, leave=False
|
| 502 |
+
)
|
| 503 |
# train
|
| 504 |
for batch in train_loader:
|
| 505 |
batch = shard(batch)
|
|
|
|
| 520 |
# ======================== Evaluating ==============================
|
| 521 |
eval_metrics = []
|
| 522 |
eval_steps = len(eval_dataset) // eval_batch_size
|
| 523 |
+
eval_step_progress_bar = tqdm(
|
| 524 |
+
total=eval_steps, desc="Evaluating...", position=2, leave=False
|
| 525 |
+
)
|
| 526 |
for batch in eval_loader:
|
| 527 |
# Model forward
|
| 528 |
batch = shard(batch)
|
|
|
|
| 538 |
|
| 539 |
# Print metrics and update progress bar
|
| 540 |
eval_step_progress_bar.close()
|
| 541 |
+
desc = (
|
| 542 |
+
f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
| 543 |
+
)
|
| 544 |
epochs.write(desc)
|
| 545 |
epochs.desc = desc
|
| 546 |
|
| 547 |
# Save metrics
|
| 548 |
if has_tensorboard and jax.process_index() == 0:
|
| 549 |
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
| 550 |
+
write_metric(
|
| 551 |
+
summary_writer, train_metrics, eval_metrics, train_time, cur_step
|
| 552 |
+
)
|
| 553 |
|
| 554 |
# save checkpoint after each epoch and push checkpoint to the hub
|
| 555 |
if jax.process_index() == 0:
|
|
|
|
| 563 |
|
| 564 |
|
| 565 |
if __name__ == "__main__":
|
| 566 |
+
main()
|
run_hybrid_clip_en.py
DELETED
|
@@ -1,570 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# coding=utf-8
|
| 3 |
-
# Copyright 2021 The HuggingFace Team All rights reserved.
|
| 4 |
-
#
|
| 5 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
-
# you may not use this file except in compliance with the License.
|
| 7 |
-
# You may obtain a copy of the License at
|
| 8 |
-
#
|
| 9 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
-
#
|
| 11 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
-
# See the License for the specific language governing permissions and
|
| 15 |
-
# limitations under the License.
|
| 16 |
-
"""
|
| 17 |
-
Training a CLIP like dual encoder models using text and vision encoders in the library.
|
| 18 |
-
The script can be used to train CLIP like models for languages other than english by using
|
| 19 |
-
a text encoder pre-trained in the desired language. Currently this script support the following vision
|
| 20 |
-
and text models:
|
| 21 |
-
Vision models: ViT(https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
|
| 22 |
-
Text models: BERT, ROBERTa (https://huggingface.co/models?filter=masked-lm)
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
import json
|
| 26 |
-
import logging
|
| 27 |
-
import os
|
| 28 |
-
import sys
|
| 29 |
-
import time
|
| 30 |
-
from dataclasses import dataclass, field
|
| 31 |
-
from pathlib import Path
|
| 32 |
-
from typing import Callable, Optional
|
| 33 |
-
|
| 34 |
-
import jax
|
| 35 |
-
import jax.numpy as jnp
|
| 36 |
-
import optax
|
| 37 |
-
import torch
|
| 38 |
-
import transformers
|
| 39 |
-
from flax import jax_utils
|
| 40 |
-
from flax.jax_utils import unreplicate
|
| 41 |
-
from flax.training import train_state
|
| 42 |
-
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
| 43 |
-
from pororo import Pororo
|
| 44 |
-
from torchvision.datasets import VisionDataset
|
| 45 |
-
from torchvision.io import ImageReadMode, read_image
|
| 46 |
-
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
|
| 47 |
-
from torchvision.transforms.functional import InterpolationMode
|
| 48 |
-
from tqdm import tqdm
|
| 49 |
-
from transformers import (
|
| 50 |
-
AutoTokenizer,
|
| 51 |
-
HfArgumentParser,
|
| 52 |
-
TrainingArguments,
|
| 53 |
-
is_tensorboard_available,
|
| 54 |
-
set_seed,
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
from dataloader import ImageTextDataset, Transform
|
| 58 |
-
from modeling_hybrid_clip import FlaxHybridCLIP
|
| 59 |
-
|
| 60 |
-
logger = logging.getLogger(__name__)
|
| 61 |
-
|
| 62 |
-
# Cache the result
|
| 63 |
-
has_tensorboard = is_tensorboard_available()
|
| 64 |
-
if has_tensorboard:
|
| 65 |
-
try:
|
| 66 |
-
from flax.metrics.tensorboard import SummaryWriter
|
| 67 |
-
except ImportError as ie:
|
| 68 |
-
has_tensorboard = False
|
| 69 |
-
print(
|
| 70 |
-
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
-
else:
|
| 74 |
-
print(
|
| 75 |
-
"Unable to display metrics through TensorBoard because the package is not installed: "
|
| 76 |
-
"Please run pip install tensorboard to enable."
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
@dataclass
|
| 81 |
-
class ModelArguments:
|
| 82 |
-
"""
|
| 83 |
-
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
| 84 |
-
"""
|
| 85 |
-
|
| 86 |
-
text_model_name_or_path: str = field(
|
| 87 |
-
metadata={
|
| 88 |
-
"help": "The text model checkpoint for weights initialization."
|
| 89 |
-
"Don't set if you want to train a model from scratch."
|
| 90 |
-
},
|
| 91 |
-
)
|
| 92 |
-
vision_model_name_or_path: str = field(
|
| 93 |
-
metadata={
|
| 94 |
-
"help": "The vision model checkpoint for weights initialization."
|
| 95 |
-
"Don't set if you want to train a model from scratch."
|
| 96 |
-
},
|
| 97 |
-
)
|
| 98 |
-
from_pt: bool = field(
|
| 99 |
-
default=True,
|
| 100 |
-
metadata={
|
| 101 |
-
"help": "whether to load the text and vision model using PyTorch checkpoints."
|
| 102 |
-
},
|
| 103 |
-
)
|
| 104 |
-
config_name: Optional[str] = field(
|
| 105 |
-
default=None,
|
| 106 |
-
metadata={
|
| 107 |
-
"help": "Pretrained config name or path if not the same as model_name"
|
| 108 |
-
},
|
| 109 |
-
)
|
| 110 |
-
tokenizer_name: Optional[str] = field(
|
| 111 |
-
default=None,
|
| 112 |
-
metadata={
|
| 113 |
-
"help": "Pretrained tokenizer name or path if not the same as model_name"
|
| 114 |
-
},
|
| 115 |
-
)
|
| 116 |
-
cache_dir: Optional[str] = field(
|
| 117 |
-
default=None,
|
| 118 |
-
metadata={
|
| 119 |
-
"help": "Where do you want to store the pretrained models downloaded from s3"
|
| 120 |
-
},
|
| 121 |
-
)
|
| 122 |
-
use_fast_tokenizer: bool = field(
|
| 123 |
-
default=True,
|
| 124 |
-
metadata={
|
| 125 |
-
"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
|
| 126 |
-
},
|
| 127 |
-
)
|
| 128 |
-
dtype: Optional[str] = field(
|
| 129 |
-
default="float32",
|
| 130 |
-
metadata={
|
| 131 |
-
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
| 132 |
-
},
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
@dataclass
|
| 137 |
-
class DataTrainingArguments:
|
| 138 |
-
"""
|
| 139 |
-
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 140 |
-
"""
|
| 141 |
-
|
| 142 |
-
data_dir: Optional[str] = field(
|
| 143 |
-
default=None, metadata={"help": "The data directory containing input files."}
|
| 144 |
-
)
|
| 145 |
-
train_file: Optional[str] = field(
|
| 146 |
-
default=None,
|
| 147 |
-
metadata={"help": "The input training data file (a jsonlines file)."},
|
| 148 |
-
)
|
| 149 |
-
validation_file: Optional[str] = field(
|
| 150 |
-
default=None,
|
| 151 |
-
metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
|
| 152 |
-
)
|
| 153 |
-
max_seq_length: Optional[int] = field(
|
| 154 |
-
default=72,
|
| 155 |
-
metadata={
|
| 156 |
-
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
| 157 |
-
"than this will be truncated, sequences shorter will be padded."
|
| 158 |
-
},
|
| 159 |
-
)
|
| 160 |
-
max_train_samples: Optional[int] = field(
|
| 161 |
-
default=None,
|
| 162 |
-
metadata={
|
| 163 |
-
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 164 |
-
"value if set."
|
| 165 |
-
},
|
| 166 |
-
)
|
| 167 |
-
max_eval_samples: Optional[int] = field(
|
| 168 |
-
default=None,
|
| 169 |
-
metadata={
|
| 170 |
-
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
| 171 |
-
"value if set."
|
| 172 |
-
},
|
| 173 |
-
)
|
| 174 |
-
overwrite_cache: bool = field(
|
| 175 |
-
default=False,
|
| 176 |
-
metadata={"help": "Overwrite the cached training and evaluation sets"},
|
| 177 |
-
)
|
| 178 |
-
overwrite_cache: bool = field(
|
| 179 |
-
default=False,
|
| 180 |
-
metadata={"help": "Overwrite the cached training and evaluation sets"},
|
| 181 |
-
)
|
| 182 |
-
preprocessing_num_workers: Optional[int] = field(
|
| 183 |
-
default=None,
|
| 184 |
-
metadata={"help": "The number of processes to use for the preprocessing."},
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
def __post_init__(self):
|
| 188 |
-
if self.train_file is None and self.validation_file is None:
|
| 189 |
-
raise ValueError(
|
| 190 |
-
"Need either a dataset name or a training/validation file."
|
| 191 |
-
)
|
| 192 |
-
else:
|
| 193 |
-
if self.train_file is not None:
|
| 194 |
-
extension = self.train_file.split(".")[-1]
|
| 195 |
-
assert extension == "json", "`train_file` should be a json file."
|
| 196 |
-
if self.validation_file is not None:
|
| 197 |
-
extension = self.validation_file.split(".")[-1]
|
| 198 |
-
assert extension == "json", "`validation_file` should be a json file."
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
class TrainState(train_state.TrainState):
|
| 202 |
-
dropout_rng: jnp.ndarray
|
| 203 |
-
|
| 204 |
-
def replicate(self):
|
| 205 |
-
return jax_utils.replicate(self).replace(
|
| 206 |
-
dropout_rng=shard_prng_key(self.dropout_rng)
|
| 207 |
-
)
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
| 211 |
-
summary_writer.scalar("train_time", train_time, step)
|
| 212 |
-
|
| 213 |
-
train_metrics = get_metrics(train_metrics)
|
| 214 |
-
for key, vals in train_metrics.items():
|
| 215 |
-
tag = f"train_{key}"
|
| 216 |
-
for i, val in enumerate(vals):
|
| 217 |
-
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
| 218 |
-
|
| 219 |
-
for metric_name, value in eval_metrics.items():
|
| 220 |
-
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
def create_learning_rate_fn(
|
| 224 |
-
train_ds_size: int,
|
| 225 |
-
train_batch_size: int,
|
| 226 |
-
num_train_epochs: int,
|
| 227 |
-
num_warmup_steps: int,
|
| 228 |
-
learning_rate: float,
|
| 229 |
-
) -> Callable[[int], jnp.array]:
|
| 230 |
-
"""Returns a linear warmup, linear_decay learning rate function."""
|
| 231 |
-
steps_per_epoch = train_ds_size // train_batch_size
|
| 232 |
-
num_train_steps = steps_per_epoch * num_train_epochs
|
| 233 |
-
warmup_fn = optax.linear_schedule(
|
| 234 |
-
init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
|
| 235 |
-
)
|
| 236 |
-
decay_fn = optax.linear_schedule(
|
| 237 |
-
init_value=learning_rate,
|
| 238 |
-
end_value=0,
|
| 239 |
-
transition_steps=num_train_steps - num_warmup_steps,
|
| 240 |
-
)
|
| 241 |
-
schedule_fn = optax.join_schedules(
|
| 242 |
-
schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
|
| 243 |
-
)
|
| 244 |
-
return schedule_fn
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
def main():
|
| 248 |
-
parser = HfArgumentParser(
|
| 249 |
-
(ModelArguments, DataTrainingArguments, TrainingArguments)
|
| 250 |
-
)
|
| 251 |
-
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 252 |
-
# If we pass only one argument to the script and it's the path to a json file,
|
| 253 |
-
# let's parse it to get our arguments.
|
| 254 |
-
model_args, data_args, training_args = parser.parse_json_file(
|
| 255 |
-
json_file=os.path.abspath(sys.argv[1])
|
| 256 |
-
)
|
| 257 |
-
else:
|
| 258 |
-
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 259 |
-
|
| 260 |
-
if (
|
| 261 |
-
os.path.exists(training_args.output_dir)
|
| 262 |
-
and os.listdir(training_args.output_dir)
|
| 263 |
-
and training_args.do_train
|
| 264 |
-
and not training_args.overwrite_output_dir
|
| 265 |
-
):
|
| 266 |
-
raise ValueError(
|
| 267 |
-
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
| 268 |
-
"Use --overwrite_output_dir to overcome."
|
| 269 |
-
)
|
| 270 |
-
|
| 271 |
-
# Make one log on every process with the configuration for debugging.
|
| 272 |
-
logging.basicConfig(
|
| 273 |
-
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 274 |
-
datefmt="%m/%d/%Y %H:%M:%S",
|
| 275 |
-
level=logging.INFO,
|
| 276 |
-
)
|
| 277 |
-
# Setup logging, we only want one process per machine to log things on the screen.
|
| 278 |
-
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
| 279 |
-
if jax.process_index() == 0:
|
| 280 |
-
transformers.utils.logging.set_verbosity_info()
|
| 281 |
-
else:
|
| 282 |
-
transformers.utils.logging.set_verbosity_error()
|
| 283 |
-
|
| 284 |
-
# Set the verbosity to info of the Transformers logger (on main process only):
|
| 285 |
-
logger.info(f"Training/evaluation parameters {training_args}")
|
| 286 |
-
|
| 287 |
-
if model_args.tokenizer_name:
|
| 288 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 289 |
-
model_args.tokenizer_name,
|
| 290 |
-
cache_dir=model_args.cache_dir,
|
| 291 |
-
use_fast=model_args.use_fast_tokenizer,
|
| 292 |
-
)
|
| 293 |
-
elif model_args.text_model_name_or_path:
|
| 294 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 295 |
-
model_args.text_model_name_or_path,
|
| 296 |
-
cache_dir=model_args.cache_dir,
|
| 297 |
-
use_fast=model_args.use_fast_tokenizer,
|
| 298 |
-
)
|
| 299 |
-
else:
|
| 300 |
-
raise ValueError(
|
| 301 |
-
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
| 302 |
-
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
| 303 |
-
)
|
| 304 |
-
|
| 305 |
-
model = FlaxHybridCLIP.from_text_vision_pretrained(
|
| 306 |
-
model_args.text_model_name_or_path,
|
| 307 |
-
model_args.vision_model_name_or_path,
|
| 308 |
-
seed=training_args.seed,
|
| 309 |
-
dtype=getattr(jnp, model_args.dtype),
|
| 310 |
-
text_from_pt=model_args.from_pt,
|
| 311 |
-
vision_from_pt=model_args.from_pt,
|
| 312 |
-
)
|
| 313 |
-
config = model.config
|
| 314 |
-
# set seed for torch dataloaders
|
| 315 |
-
set_seed(training_args.seed)
|
| 316 |
-
|
| 317 |
-
# Initialize torchvision transforms and jit them for faster processing
|
| 318 |
-
preprocess = Transform(config.vision_config.image_size)
|
| 319 |
-
preprocess = torch.jit.script(preprocess)
|
| 320 |
-
|
| 321 |
-
# Initialize the image-text dataset
|
| 322 |
-
train_dataset = ImageTextDataset(
|
| 323 |
-
data_args.data_dir,
|
| 324 |
-
data_args.train_file,
|
| 325 |
-
captions_per_image=2,
|
| 326 |
-
transform=preprocess,
|
| 327 |
-
)
|
| 328 |
-
|
| 329 |
-
eval_dataset = ImageTextDataset(
|
| 330 |
-
data_args.data_dir,
|
| 331 |
-
data_args.validation_file,
|
| 332 |
-
captions_per_image=1,
|
| 333 |
-
transform=preprocess,
|
| 334 |
-
)
|
| 335 |
-
# Import Translation Pipeline
|
| 336 |
-
mt = Pororo(task="translation", lang="multi")
|
| 337 |
-
|
| 338 |
-
# Store some constant
|
| 339 |
-
num_epochs = int(training_args.num_train_epochs)
|
| 340 |
-
train_batch_size = (
|
| 341 |
-
int(training_args.per_device_train_batch_size) * jax.device_count()
|
| 342 |
-
)
|
| 343 |
-
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
| 344 |
-
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 345 |
-
total_train_steps = steps_per_epoch * num_epochs
|
| 346 |
-
|
| 347 |
-
# Use collate function to tokenizer the text and convert the processed images to numpy
|
| 348 |
-
def collate_fn(examples):
|
| 349 |
-
pixel_values = (
|
| 350 |
-
torch.stack([example[0] for example in examples])
|
| 351 |
-
.permute(0, 2, 3, 1)
|
| 352 |
-
.numpy()
|
| 353 |
-
)
|
| 354 |
-
en_captions = [example[1] for example in examples]
|
| 355 |
-
captions = [mt(text, src="en", tgt="ko") for text in en_captions]
|
| 356 |
-
inputs = tokenizer(
|
| 357 |
-
captions,
|
| 358 |
-
max_length=data_args.max_seq_length,
|
| 359 |
-
padding="max_length",
|
| 360 |
-
return_tensors="np",
|
| 361 |
-
)
|
| 362 |
-
|
| 363 |
-
batch = {
|
| 364 |
-
"pixel_values": pixel_values,
|
| 365 |
-
"input_ids": inputs["input_ids"],
|
| 366 |
-
"attention_mask": inputs["attention_mask"],
|
| 367 |
-
}
|
| 368 |
-
|
| 369 |
-
return batch
|
| 370 |
-
|
| 371 |
-
# Create data loaders
|
| 372 |
-
train_loader = torch.utils.data.DataLoader(
|
| 373 |
-
train_dataset,
|
| 374 |
-
batch_size=train_batch_size,
|
| 375 |
-
shuffle=True,
|
| 376 |
-
num_workers=data_args.preprocessing_num_workers,
|
| 377 |
-
persistent_workers=True,
|
| 378 |
-
drop_last=True,
|
| 379 |
-
collate_fn=collate_fn,
|
| 380 |
-
)
|
| 381 |
-
|
| 382 |
-
eval_loader = torch.utils.data.DataLoader(
|
| 383 |
-
eval_dataset,
|
| 384 |
-
batch_size=eval_batch_size,
|
| 385 |
-
shuffle=False,
|
| 386 |
-
num_workers=data_args.preprocessing_num_workers,
|
| 387 |
-
persistent_workers=True,
|
| 388 |
-
drop_last=True,
|
| 389 |
-
collate_fn=collate_fn,
|
| 390 |
-
)
|
| 391 |
-
|
| 392 |
-
# Enable tensorboard only on the master node
|
| 393 |
-
if has_tensorboard and jax.process_index() == 0:
|
| 394 |
-
summary_writer = SummaryWriter(
|
| 395 |
-
log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
|
| 396 |
-
)
|
| 397 |
-
|
| 398 |
-
# Initialize our training
|
| 399 |
-
rng = jax.random.PRNGKey(training_args.seed)
|
| 400 |
-
rng, dropout_rng = jax.random.split(rng)
|
| 401 |
-
|
| 402 |
-
# Create learning rate schedule
|
| 403 |
-
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
| 404 |
-
len(train_dataset),
|
| 405 |
-
train_batch_size,
|
| 406 |
-
training_args.num_train_epochs,
|
| 407 |
-
training_args.warmup_steps,
|
| 408 |
-
training_args.learning_rate,
|
| 409 |
-
)
|
| 410 |
-
|
| 411 |
-
# create adam optimizer
|
| 412 |
-
adamw = optax.adamw(
|
| 413 |
-
learning_rate=linear_decay_lr_schedule_fn,
|
| 414 |
-
b1=training_args.adam_beta1,
|
| 415 |
-
b2=training_args.adam_beta2,
|
| 416 |
-
eps=training_args.adam_epsilon,
|
| 417 |
-
weight_decay=training_args.weight_decay,
|
| 418 |
-
)
|
| 419 |
-
|
| 420 |
-
# Setup train state
|
| 421 |
-
state = TrainState.create(
|
| 422 |
-
apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng
|
| 423 |
-
)
|
| 424 |
-
|
| 425 |
-
def cross_entropy(logits, axis):
|
| 426 |
-
logprobs = jax.nn.log_softmax(logits, axis=axis)
|
| 427 |
-
nll = jnp.diag(logprobs)
|
| 428 |
-
ce = -jnp.mean(nll)
|
| 429 |
-
return ce
|
| 430 |
-
|
| 431 |
-
def clip_loss(similarity):
|
| 432 |
-
loss = (
|
| 433 |
-
cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
|
| 434 |
-
) / 2
|
| 435 |
-
return loss
|
| 436 |
-
|
| 437 |
-
# Define gradient update step fn
|
| 438 |
-
def train_step(state, batch):
|
| 439 |
-
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
| 440 |
-
|
| 441 |
-
def compute_loss(params):
|
| 442 |
-
logits = state.apply_fn(
|
| 443 |
-
**batch, params=params, dropout_rng=dropout_rng, train=True
|
| 444 |
-
)[0]
|
| 445 |
-
loss = clip_loss(logits)
|
| 446 |
-
return loss
|
| 447 |
-
|
| 448 |
-
grad_fn = jax.value_and_grad(compute_loss)
|
| 449 |
-
loss, grad = grad_fn(state.params)
|
| 450 |
-
grad = jax.lax.pmean(grad, "batch")
|
| 451 |
-
|
| 452 |
-
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
| 453 |
-
|
| 454 |
-
metrics = {
|
| 455 |
-
"loss": loss,
|
| 456 |
-
"learning_rate": linear_decay_lr_schedule_fn(state.step),
|
| 457 |
-
}
|
| 458 |
-
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 459 |
-
|
| 460 |
-
return new_state, metrics
|
| 461 |
-
|
| 462 |
-
# Define eval fn
|
| 463 |
-
def eval_step(params, batch):
|
| 464 |
-
logits = model(**batch, params=params, train=False)[0]
|
| 465 |
-
loss = clip_loss(logits)
|
| 466 |
-
|
| 467 |
-
# summarize metrics
|
| 468 |
-
metrics = {"loss": loss}
|
| 469 |
-
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 470 |
-
return metrics
|
| 471 |
-
|
| 472 |
-
# Create parallel version of the train and eval step
|
| 473 |
-
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
| 474 |
-
p_eval_step = jax.pmap(eval_step, "batch")
|
| 475 |
-
|
| 476 |
-
# Replicate the train state on each device
|
| 477 |
-
state = state.replicate()
|
| 478 |
-
|
| 479 |
-
logger.info("***** Running training *****")
|
| 480 |
-
logger.info(f" Num examples = {len(train_dataset)}")
|
| 481 |
-
logger.info(f" Num Epochs = {num_epochs}")
|
| 482 |
-
logger.info(
|
| 483 |
-
f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
|
| 484 |
-
)
|
| 485 |
-
logger.info(
|
| 486 |
-
f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
|
| 487 |
-
)
|
| 488 |
-
logger.info(f" Total optimization steps = {total_train_steps}")
|
| 489 |
-
|
| 490 |
-
train_time = 0
|
| 491 |
-
# Create sampling rng
|
| 492 |
-
rng, input_rng = jax.random.split(rng)
|
| 493 |
-
|
| 494 |
-
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
| 495 |
-
for epoch in epochs:
|
| 496 |
-
# ======================== Training ================================
|
| 497 |
-
train_start = time.time()
|
| 498 |
-
|
| 499 |
-
# Create sampling rng
|
| 500 |
-
rng, input_rng = jax.random.split(rng)
|
| 501 |
-
train_metrics = []
|
| 502 |
-
|
| 503 |
-
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 504 |
-
train_step_progress_bar = tqdm(
|
| 505 |
-
total=steps_per_epoch, desc="Training...", position=1, leave=False
|
| 506 |
-
)
|
| 507 |
-
# train
|
| 508 |
-
for batch in train_loader:
|
| 509 |
-
batch = shard(batch)
|
| 510 |
-
state, train_metric = p_train_step(state, batch)
|
| 511 |
-
train_metrics.append(train_metric)
|
| 512 |
-
|
| 513 |
-
train_step_progress_bar.update(1)
|
| 514 |
-
|
| 515 |
-
train_time += time.time() - train_start
|
| 516 |
-
|
| 517 |
-
train_metric = unreplicate(train_metric)
|
| 518 |
-
|
| 519 |
-
train_step_progress_bar.close()
|
| 520 |
-
epochs.write(
|
| 521 |
-
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
| 522 |
-
)
|
| 523 |
-
|
| 524 |
-
# ======================== Evaluating ==============================
|
| 525 |
-
eval_metrics = []
|
| 526 |
-
eval_steps = len(eval_dataset) // eval_batch_size
|
| 527 |
-
eval_step_progress_bar = tqdm(
|
| 528 |
-
total=eval_steps, desc="Evaluating...", position=2, leave=False
|
| 529 |
-
)
|
| 530 |
-
for batch in eval_loader:
|
| 531 |
-
# Model forward
|
| 532 |
-
batch = shard(batch)
|
| 533 |
-
metrics = p_eval_step(state.params, batch)
|
| 534 |
-
eval_metrics.append(metrics)
|
| 535 |
-
|
| 536 |
-
eval_step_progress_bar.update(1)
|
| 537 |
-
|
| 538 |
-
# normalize eval metrics
|
| 539 |
-
eval_metrics = get_metrics(eval_metrics)
|
| 540 |
-
|
| 541 |
-
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
| 542 |
-
|
| 543 |
-
# Print metrics and update progress bar
|
| 544 |
-
eval_step_progress_bar.close()
|
| 545 |
-
desc = (
|
| 546 |
-
f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
| 547 |
-
)
|
| 548 |
-
epochs.write(desc)
|
| 549 |
-
epochs.desc = desc
|
| 550 |
-
|
| 551 |
-
# Save metrics
|
| 552 |
-
if has_tensorboard and jax.process_index() == 0:
|
| 553 |
-
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
| 554 |
-
write_metric(
|
| 555 |
-
summary_writer, train_metrics, eval_metrics, train_time, cur_step
|
| 556 |
-
)
|
| 557 |
-
|
| 558 |
-
# save checkpoint after each epoch and push checkpoint to the hub
|
| 559 |
-
if jax.process_index() == 0:
|
| 560 |
-
params = jax.device_get(unreplicate(state.params))
|
| 561 |
-
model.save_pretrained(
|
| 562 |
-
training_args.output_dir,
|
| 563 |
-
params=params,
|
| 564 |
-
push_to_hub=training_args.push_to_hub,
|
| 565 |
-
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
| 566 |
-
)
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
if __name__ == "__main__":
|
| 570 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train.sh
CHANGED
|
@@ -1,15 +1,14 @@
|
|
| 1 |
python run_hybrid_clip.py \
|
| 2 |
-
--output_dir
|
| 3 |
--text_model_name_or_path="klue/roberta-large" \
|
| 4 |
--vision_model_name_or_path="openai/clip-vit-base-patch32" \
|
| 5 |
--tokenizer_name="klue/roberta-large" \
|
| 6 |
-
--train_file="
|
| 7 |
-
--validation_file="
|
| 8 |
--do_train --do_eval \
|
| 9 |
--num_train_epochs="40" --max_seq_length 96 \
|
| 10 |
--per_device_train_batch_size="64" \
|
| 11 |
--per_device_eval_batch_size="64" \
|
| 12 |
--learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
|
| 13 |
--overwrite_output_dir \
|
| 14 |
-
--preprocessing_num_workers 32
|
| 15 |
-
--push_to_hub
|
|
|
|
| 1 |
python run_hybrid_clip.py \
|
| 2 |
+
--output_dir="models/coco_only" \
|
| 3 |
--text_model_name_or_path="klue/roberta-large" \
|
| 4 |
--vision_model_name_or_path="openai/clip-vit-base-patch32" \
|
| 5 |
--tokenizer_name="klue/roberta-large" \
|
| 6 |
+
--train_file="../dataset/coco/train_annotations.json" \
|
| 7 |
+
--validation_file="../dataset/coco/valid_annotations.json" \
|
| 8 |
--do_train --do_eval \
|
| 9 |
--num_train_epochs="40" --max_seq_length 96 \
|
| 10 |
--per_device_train_batch_size="64" \
|
| 11 |
--per_device_eval_batch_size="64" \
|
| 12 |
--learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
|
| 13 |
--overwrite_output_dir \
|
| 14 |
+
--preprocessing_num_workers 32
|
|
|