Merge pull request #118 from borisdayma/feat-optim
Files changed:
- src/dalle_mini/data.py  +11 −12
- tools/inference/inference_pipeline.ipynb  +46 −19
- tools/train/train.py  +129 −101
src/dalle_mini/data.py
CHANGED
@@ -161,7 +161,7 @@ class Dataset:
        ):
            """
            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
-           Shuffle batches if
+           Shuffle batches if rng is set.
            """
            steps_per_epoch = len(dataset) // batch_size

@@ -182,19 +182,20 @@ class Dataset:
                yield batch

        def _dataloader_datasets_streaming(
-           dataset: Dataset, batch_size: int, epoch: int
+           dataset: Dataset, split: str, batch_size: int, epoch: int
        ):
-           # epoch is only use for multi-host
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
-           first_loop = True
-           while self.multi_hosts or first_loop:
+           first_loop = True  # stop after one loop in some cases
+           while (self.multi_hosts and split == "train") or first_loop:
                # in multi-host, we run forever (no epoch) as hosts need to stop
-               # at the same time and
-
-
-
+               # at the same time and training data may not be split equally
+               # For validation data we put the entire set on each host as we could lose
+               # too many samples on pods
+               if epoch is not None:
+                   # reshuffle training data at each epoch (not applicable with validation set)
                    dataset.set_epoch(epoch)
+                   epoch += 1
                for item in dataset:
                    for k, v in item.items():
                        batch[k].append(v)

@@ -213,9 +214,7 @@ class Dataset:
            raise ValueError(f'split must be "train" or "eval", got {split}')

        if self.streaming:
-
-           ds.set_epoch(epoch)
-           return _dataloader_datasets_streaming(ds, batch_size, epoch)
+           return _dataloader_datasets_streaming(ds, split, batch_size, epoch)
        else:
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
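For reference, the change above makes the streaming loader stop after a single pass except for multi-host training, which must keep every host iterating in lockstep over data that may not be split equally. A minimal standalone sketch of that stopping rule (the `streaming_batches` helper and its arguments are illustrative, not the repository's code):

    # Sketch only: mirrors the loop condition introduced in _dataloader_datasets_streaming.
    def streaming_batches(dataset, split, batch_size, multi_hosts, epoch=0):
        batch = []
        first_loop = True
        while (multi_hosts and split == "train") or first_loop:
            # in the real loader this is where dataset.set_epoch(epoch) reshuffles shards
            for item in dataset:
                batch.append(item)
                if len(batch) == batch_size:
                    yield batch
                    batch = []
            epoch += 1
            first_loop = False

    # eval makes a single pass even when multi_hosts is True
    print(list(streaming_batches(range(6), "eval", batch_size=2, multi_hosts=True)))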
tools/inference/inference_pipeline.ipynb
CHANGED
@@ -70,15 +70,15 @@
     "# Model references\n",
     "\n",
     "# dalle-mini\n",
-    "DALLE_MODEL =
+    "DALLE_MODEL = \"dalle-mini/dalle-mini/model-3bqwu04f:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
     "DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
     "\n",
     "# VQGAN model\n",
-    "VQGAN_REPO =
-    "VQGAN_COMMIT_ID =
+    "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
+    "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
     "\n",
     "# CLIP model\n",
-    "CLIP_REPO =
+    "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
     "CLIP_COMMIT_ID = None"
    ]
   },

@@ -121,18 +121,28 @@
     "import wandb\n",
     "\n",
     "# Load dalle-mini\n",
-    "if
+    "if \":\" in DALLE_MODEL:\n",
     "    # wandb artifact\n",
     "    artifact = wandb.Api().artifact(DALLE_MODEL)\n",
     "    # we only download required files (no need for opt_state which is large)\n",
-    "    model_files = [
+    "    model_files = [\n",
+    "        \"config.json\",\n",
+    "        \"flax_model.msgpack\",\n",
+    "        \"merges.txt\",\n",
+    "        \"special_tokens_map.json\",\n",
+    "        \"tokenizer.json\",\n",
+    "        \"tokenizer_config.json\",\n",
+    "        \"vocab.json\",\n",
+    "    ]\n",
     "    for f in model_files:\n",
-    "        artifact.get_path(f).download(
-    "    model = DalleBart.from_pretrained(
-    "    tokenizer = AutoTokenizer.from_pretrained(
+    "        artifact.get_path(f).download(\"model\")\n",
+    "    model = DalleBart.from_pretrained(\"model\", dtype=dtype, abstract_init=True)\n",
+    "    tokenizer = AutoTokenizer.from_pretrained(\"model\")\n",
     "else:\n",
     "    # local folder or 🤗 Hub\n",
-    "    model = DalleBart.from_pretrained(
+    "    model = DalleBart.from_pretrained(\n",
+    "        DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
+    "    )\n",
     "    tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
     "\n",
     "# Load VQGAN\n",

@@ -191,7 +201,7 @@
     "from functools import partial\n",
     "\n",
     "# model inference\n",
-    "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3,4))\n",
+    "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
     "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
     "    return model.generate(\n",
     "        **tokenized_prompt,\n",

@@ -203,11 +213,13 @@
     "        top_p=top_p\n",
     "    )\n",
     "\n",
+    "\n",
     "# decode images\n",
     "@partial(jax.pmap, axis_name=\"batch\")\n",
     "def p_decode(indices, params):\n",
     "    return vqgan.decode_code(indices, params=params)\n",
     "\n",
+    "\n",
     "# score images\n",
     "@partial(jax.pmap, axis_name=\"batch\")\n",
     "def p_clip(inputs, params):\n",

@@ -235,7 +247,7 @@
     "import random\n",
     "\n",
     "# create a random key\n",
-    "seed = random.randint(0, 2**32-1)\n",
+    "seed = random.randint(0, 2 ** 32 - 1)\n",
     "key = jax.random.PRNGKey(seed)"
    ]
   },

@@ -287,7 +299,7 @@
   },
   "outputs": [],
   "source": [
-    "prompt =
+    "prompt = \"a red T-shirt\""
   ]
  },
  {

@@ -323,7 +335,13 @@
     "repeated_prompts = [processed_prompt] * jax.device_count()\n",
     "\n",
     "# tokenize\n",
-    "tokenized_prompt = tokenizer(
+    "tokenized_prompt = tokenizer(\n",
+    "    repeated_prompts,\n",
+    "    return_tensors=\"jax\",\n",
+    "    padding=\"max_length\",\n",
+    "    truncation=True,\n",
+    "    max_length=128,\n",
+    ").data\n",
     "tokenized_prompt"
    ]
   },

@@ -408,12 +426,14 @@
     "    # get a new key\n",
     "    key, subkey = jax.random.split(key)\n",
     "    # generate images\n",
-    "    encoded_images = p_generate(
+    "    encoded_images = p_generate(\n",
+    "        tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
+    "    )\n",
     "    # remove BOS\n",
     "    encoded_images = encoded_images.sequences[..., 1:]\n",
     "    # decode images\n",
     "    decoded_images = p_decode(encoded_images, vqgan_params)\n",
-    "    decoded_images = decoded_images.clip(0
+    "    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
     "    for img in decoded_images:\n",
     "        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
    ]

@@ -436,7 +456,14 @@
   "outputs": [],
   "source": [
     "# get clip scores\n",
-    "clip_inputs = processor(
+    "clip_inputs = processor(\n",
+    "    text=[prompt] * jax.device_count(),\n",
+    "    images=images,\n",
+    "    return_tensors=\"np\",\n",
+    "    padding=\"max_length\",\n",
+    "    max_length=77,\n",
+    "    truncation=True,\n",
+    ").data\n",
     "logits = p_clip(shard(clip_inputs), clip_params)\n",
     "logits = logits.squeeze().flatten()"
    ]

@@ -458,10 +485,10 @@
   },
   "outputs": [],
   "source": [
-    "print(f
+    "print(f\"Prompt: {prompt}\\n\")\n",
     "for idx in logits.argsort()[::-1]:\n",
     "    display(images[idx])\n",
-    "    print(f
+    "    print(f\"Score: {logits[idx]:.2f}\\n\")"
    ]
   }
  ],
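The notebook relies on `jax.pmap` with `static_broadcasted_argnums` so that `top_k` and `top_p` are treated as compile-time constants rather than sharded arrays. A minimal self-contained sketch of that mechanism (the `scale` function and its values are illustrative, not part of the notebook; it assumes any number of local devices, including a single CPU device):

    from functools import partial

    import jax
    import jax.numpy as jnp

    # Positional arguments 1 and 2 are static: they are broadcast to every device
    # and baked into the compiled program instead of being split along axis 0.
    @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(1, 2))
    def scale(x, factor, bias):
        return x * factor + bias

    n_dev = jax.local_device_count()
    x = jnp.ones((n_dev, 3))
    print(scale(x, 2.0, 0.5))  # -> an (n_dev, 3) array filled with 2.5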
tools/train/train.py
CHANGED
@@ -65,7 +65,7 @@ class ModelArguments:
    config_name: Optional[str] = field(
        default=None,
        metadata={
-           "help": "Pretrained config name or path if not the same as
+           "help": "Pretrained config name or path if not the same as model_name_or_path"
        },
    )
    tokenizer_name: Optional[str] = field(

@@ -77,7 +77,7 @@ class ModelArguments:
    dtype: Optional[str] = field(
        default="float32",
        metadata={
-           "help": "Floating-point format in which the
+           "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
        },
    )

@@ -106,11 +106,15 @@ class DataTrainingArguments:
    )
    train_file: Optional[str] = field(
        default=None,
-       metadata={
+       metadata={
+           "help": "The input training data file (glob & braceexpand acceptable)."
+       },
    )
    validation_file: Optional[str] = field(
        default=None,
-       metadata={
+       metadata={
+           "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
+       },
    )
    # data loading should not be a bottleneck so we use "streaming" mode by default
    streaming: Optional[bool] = field(

@@ -132,15 +136,13 @@ class DataTrainingArguments:
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
-           "help": "For debugging purposes or quicker training, truncate the number of training examples
-           "value if set."
+           "help": "For debugging purposes or quicker training, truncate the number of training examples."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
-           "help": "For debugging purposes or quicker training, truncate the number of evaluation examples
-           "value if set."
+           "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
        },
    )
    preprocessing_num_workers: Optional[int] = field(

@@ -191,42 +193,40 @@ class TrainingArguments:

    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(
-       default=False, metadata={"help": "Whether to run eval on the
+       default=False, metadata={"help": "Whether to run eval on the validation set."}
    )

    per_device_train_batch_size: int = field(
-       default=8, metadata={"help": "Batch size per GPU/TPU
+       default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
-       default=8, metadata={"help": "Batch size per GPU/TPU
+       default=8, metadata={"help": "Batch size per GPU/TPU/CPU for evaluation."}
    )

    gradient_accumulation_steps: int = field(
        default=1,
        metadata={
-           "help": "Number of updates steps to accumulate before performing
+           "help": "Number of updates steps to accumulate before performing an update pass."
        },
    )

    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate."}
    )
-
-       default=
-       metadata={
-
-
-       default=False,
-       metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
-   )
-   weight_decay: float = field(
-       default=None, metadata={"help": "Weight decay if we apply some."}
+   optim: str = field(
+       default="distributed_shampoo",
+       metadata={
+           "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
+       },
    )
-
-
+   weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
+   beta1: float = field(
+       default=0.9,
+       metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
    )
-
-       default=0.999,
+   beta2: float = field(
+       default=0.999,
+       metadata={"help": "Beta2 for for Adam & Distributed Shampoo."},
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}

@@ -234,9 +234,47 @@
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
    )
-
+   block_size: int = field(
+       default=1024,
+       metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
+   )
+   preconditioning_compute_steps: int = field(
+       default=10, metadata={"help": "Number of steps to update preconditioner."}
+   )
+   skip_preconditioning_dim_size_gt: int = field(
+       default=4096,
+       metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
+   )
+   optim_quantized: bool = field(
+       default=False,
+       metadata={
+           "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
+       },
+   )
+
+   lr_decay: str = field(
+       default=None,
+       metadata={
+           "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
+       },
+   )
+   lr_transition_steps: int = field(
+       default=None,
+       metadata={
+           "help": "Number of transition steps associated with learning rate decay when using exponential decay."
+       },
+   )
+   lr_decay_rate: float = field(
+       default=None,
+       metadata={
+           "help": "Decay rate associated with learning rate when using exponential decay."
+       },
+   )
+   lr_staircase: bool = field(
        default=False,
-       metadata={
+       metadata={
+           "help": "Whether to use staircase or continuous learning rate when using exponential decay."
+       },
    )

    num_train_epochs: float = field(

@@ -267,18 +305,18 @@
        },
    )

-   push_to_hub: bool = field(
-       default=False,
-       metadata={
-           "help": "Whether or not to upload the trained model to the model hub after training."
-       },
-   )
-
    resume_from_checkpoint: Optional[str] = field(
        default=None,
        metadata={"help": "Reference to a wandb artifact for resuming training."},
    )

+   def __post_init__(self):
+       assert self.optim in [
+           "distributed_shampoo",
+           "adam",
+           "adafactor",
+       ], f"Selected optimizer not supported: {self.optim}"
+

class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray = None

@@ -309,33 +347,6 @@ class TrainState(train_state.TrainState):
    )


-def create_learning_rate_fn(
-   num_warmup_steps: int,
-   learning_rate: float,
-   use_decay: bool,
-   num_train_steps: int = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
-) -> Callable[[int], jnp.array]:
-   """Returns a linear warmup, linear_decay learning rate function."""
-   if use_decay:
-       assert (
-           num_train_steps is not None
-       ), "Learning rate with decay requires number of training steps"
-   warmup_fn = optax.linear_schedule(
-       init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
-   )
-   if not use_decay:
-       return warmup_fn
-   decay_fn = optax.linear_schedule(
-       init_value=learning_rate,
-       end_value=0,
-       transition_steps=num_train_steps - num_warmup_steps,
-   )
-   schedule_fn = optax.join_schedules(
-       schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
-   )
-   return schedule_fn
-
-
class MetricsLogger:
    def __init__(self, state):
        self.step = state.step

@@ -529,12 +540,37 @@ def main():
    num_params = model.num_params

    # Create learning rate schedule
-
-
-
-
-
-
+   def create_learning_rate_fn() -> Callable[[int], jnp.array]:
+       """Create the learning rate function."""
+       warmup_fn = optax.linear_schedule(
+           init_value=0.0,
+           end_value=training_args.learning_rate,
+           transition_steps=training_args.warmup_steps,
+       )
+       if training_args.lr_decay is None:
+           return warmup_fn
+       elif training_args.lr_decay == "linear":
+           assert (
+               num_train_steps is not None
+           ), "linear decay requires knowing the dataset length"
+           decay_fn = optax.linear_schedule(
+               init_value=training_args.learning_rate,
+               end_value=0,
+               transition_steps=num_train_steps - training_args.warmup_steps,
+           )
+       elif training_args.lr_decay == "exponential":
+           decay_fn = optax.exponential_decay(
+               init_value=training_args.learning_rate,
+               transition_steps=training_args.lr_transition_steps,
+               decay_rate=training_args.lr_decay_rate,
+               staircase=training_args.lr_staircase,
+           )
+       schedule_fn = optax.join_schedules(
+           schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+       )
+       return schedule_fn
+
+   learning_rate_fn = create_learning_rate_fn()

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a

@@ -558,29 +594,22 @@
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
-   if training_args.
-       # We use the default parameters here to initialize adafactor,
-       # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
-       optimizer = optax.adafactor(
-           learning_rate=learning_rate_fn,
-           weight_decay_rate=training_args.weight_decay,
-           weight_decay_mask=decay_mask_fn,
-           clipping_threshold=training_args.max_grad_norm,
-       )
-   elif training_args.distributed_shampoo:
+   if training_args.optim == "distributed_shampoo":
        # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
        # Notes:
-       #  - mask for weight decay is not implemented
+       #  - mask for weight decay is not implemented
        optimizer = distributed_shampoo(
            learning_rate_fn,
-           block_size=
-           beta1=
-           beta2=
+           block_size=training_args.block_size,
+           beta1=training_args.beta1,
+           beta2=training_args.beta2,
            diagonal_epsilon=1e-10,
            matrix_epsilon=1e-8,
-           weight_decay=
-
-
+           weight_decay=training_args.weight_decay
+           if training_args.weight_decay is not None
+           else 0.0,
+           start_preconditioning_step=training_args.warmup_steps,
+           preconditioning_compute_steps=training_args.preconditioning_compute_steps,
            statistics_compute_steps=1,
            best_effort_shape_interpretation=True,
            graft_type=GraftingType.RMSPROP_NORMALIZED,

@@ -589,23 +618,32 @@
            batch_axis_name="batch",
            inverse_failure_threshold=0.1,
            moving_average_for_momentum=True,
-           skip_preconditioning_dim_size_gt=
+           skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
            clip_by_scaled_gradient_norm=None,
            precision=jax.lax.Precision.HIGHEST,
-           best_effort_memory_usage_reduction=
+           best_effort_memory_usage_reduction=training_args.optim_quantized,
        )

-
+   elif training_args.optim == "adam":
        optimizer = optax.adamw(
            learning_rate=learning_rate_fn,
-           b1=training_args.
-           b2=training_args.
+           b1=training_args.beta1,
+           b2=training_args.beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay
            if training_args.weight_decay is not None
            else 0.0,
            mask=decay_mask_fn,
        )
+   elif training_args.optim == "adafactor":
+       # We use the default parameters here to initialize adafactor,
+       # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+       optimizer = optax.adafactor(
+           learning_rate=learning_rate_fn,
+           weight_decay_rate=training_args.weight_decay,
+           weight_decay_mask=decay_mask_fn,
+           clipping_threshold=training_args.max_grad_norm,
+       )

    # add gradient accumulation
    if training_args.gradient_accumulation_steps > 1:

@@ -821,16 +859,6 @@

        wandb.run.log_artifact(artifact)

-       # save to the hub
-       if training_args.push_to_hub:
-           model.save_pretrained(
-               training_args.output_dir,
-               params=params,
-               push_to_hub=training_args.push_to_hub,
-               commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
-               temp_dir=True,  # avoid issues with being in a repository
-           )
-
    # init variables
    last_time = time.perf_counter()
    train_metrics = None

@@ -841,7 +869,7 @@
        metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))

        # Generate an epoch by shuffling sampling indices from the train dataset
-       train_loader = dataset.dataloader("train", train_batch_size)
+       train_loader = dataset.dataloader("train", train_batch_size, epoch)
        # train
        for batch in tqdm(
            train_loader,
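The new `create_learning_rate_fn` composes a linear warmup with an optional decay through `optax.join_schedules`. A standalone sketch of that composition (the step counts, learning rate, and decay rate below are illustrative, not the project's defaults):

    import optax

    warmup_steps = 100      # illustrative values only
    learning_rate = 5e-5

    # linear warmup from 0 to the peak learning rate
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
    )
    # exponential decay applied after the warmup boundary
    decay_fn = optax.exponential_decay(
        init_value=learning_rate, transition_steps=1000, decay_rate=0.5, staircase=False
    )
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
    )

    for step in (0, 50, 100, 1100):
        print(step, schedule_fn(step))  # ramps up to 5e-5 by step 100, then halves every 1000 steps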